|
import argparse |
|
import datetime |
|
import logging |
|
from pathlib import Path |
|
from typing import Any, List, Literal, Tuple |
|
|
|
from pydantic import BaseModel, ValidationInfo, field_validator |
|
|
|
|
|
class Args(BaseModel):
    """Validated training configuration, normally built from the CLI by ``parse_args``.

    Field order is significant: pydantic validates fields in declaration order and
    the validators below read previously validated fields through ``info.data``
    (e.g. ``model_type``, ``do_validation``, ``checkpointing_steps``). Do not move
    a field below another field whose validator depends on it.
    """

    # ---- Model ----
    model_path: Path
    model_name: str
    model_type: Literal["i2v", "t2v", "i2vFlow"]
    training_type: Literal["lora", "sft", "controlnet"] = "lora"
    # Comma-separated on the CLI; `parse_args` only accepts it when training_type is "lora".
    additional_save_blocks: List[str] | None = None
    depth_ckpt_path: str

    # ---- Output / tracking ----
    # NOTE: this default is computed once, when the class body is executed
    # (i.e. at import time), not each time an instance is created.
    output_dir: Path = Path("train_results/{:%Y-%m-%d-%H-%M-%S}".format(datetime.datetime.now()))
    report_to: Literal["tensorboard", "wandb", "all"] | None = None
    tracker_name: str = "finetrainer-cogvideo"
    run_name: str = "CogVideoX"

    # ---- Data ----
    data_root: Path
    caption_column: Path
    image_column: Path | None = None
    video_column: Path

    # ---- Training schedule ----
    resume_from_checkpoint: Path | None = None

    seed: int | None = None
    train_epochs: int
    train_steps: int | None = None
    checkpointing_steps: int = 200
    checkpointing_limit: int = 10

    batch_size: int
    gradient_accumulation_steps: int = 1

    # Unpacked by the validator as (frames, height, width).
    train_resolution: Tuple[int, int, int]

    mixed_precision: Literal["no", "fp16", "bf16"]

    # ---- Optimizer ----
    learning_rate: float = 2e-5
    optimizer: str = "adamw"
    beta1: float = 0.9
    beta2: float = 0.95
    beta3: float = 0.98
    epsilon: float = 1e-8
    weight_decay: float = 1e-4
    max_grad_norm: float = 1.0

    # ---- Learning-rate scheduler ----
    lr_scheduler: str = "constant_with_warmup"
    lr_warmup_steps: int = 100
    lr_num_cycles: int = 1
    lr_power: float = 1.0

    # ---- Data loading ----
    num_workers: int = 8
    pin_memory: bool = True

    # ---- Memory / runtime switches ----
    gradient_checkpointing: bool = True
    enable_slicing: bool = True
    enable_tiling: bool = True
    nccl_timeout: int = 1800

    # ---- LoRA ----
    rank: int = 128
    lora_alpha: int = 64
    target_modules: List[str] = ["to_q", "to_k", "to_v", "to_out.0"]

    # ---- Validation ----
    do_validation: bool = False
    # The five fields below have no default: they must always be supplied
    # (possibly as None); `parse_args` passes None when the flag is omitted.
    validation_steps: int | None
    validation_dir: Path | None
    validation_prompts: str | None
    validation_images: str | None
    validation_videos: str | None
    gen_fps: int = 15
    max_scene: int = 8

    # ---- ControlNet / time sampling ----
    controlnet_transformer_num_layers: int = 8
    controlnet_input_channels: int = 16
    controlnet_weights: float = 1.0
    controlnet_guidance_start: float = 0.0
    controlnet_guidance_end: float = 1.0
    controlnet_out_proj_dim_factor: int = 64
    controlnet_out_proj_zero_init: bool = True
    enable_time_sampling: bool = True
    time_sampling_type: str = 'truncated_normal'
    time_sampling_mean: float = 0.95
    time_sampling_std: float = 0.1
    use_valid_mask: bool = False
    notextinflow: bool = False

    @field_validator("image_column")
    @classmethod
    def validate_image_column(cls, v: Path | None, info: ValidationInfo) -> Path | None:
        """Log a warning when an ``i2v`` model has no ``image_column``.

        The value itself is never rejected or modified.
        """
        if info.data.get("model_type") == "i2v" and not v:
            logging.warning(
                "No `image_column` specified for i2v model. Will automatically extract first frames from videos as conditioning images."
            )
        return v

    @field_validator("validation_dir", "validation_prompts")
    @classmethod
    def validate_validation_required_fields(cls, v: Any, info: ValidationInfo) -> Any:
        """Require ``validation_dir`` / ``validation_prompts`` when ``do_validation`` is set.

        Raises:
            ValueError: if ``do_validation`` is truthy and the field is empty or None.
        """
        if info.data.get("do_validation") and not v:
            field_name = info.field_name
            raise ValueError(f"{field_name} must be specified when do_validation is True")
        return v

    @field_validator("validation_images")
    @classmethod
    def validate_validation_images(cls, v: str | None, info: ValidationInfo) -> str | None:
        """Require ``validation_images`` when validating an ``i2v`` model.

        Raises:
            ValueError: if ``do_validation`` is truthy, ``model_type`` is ``"i2v"``
                and the field is empty or None.
        """
        values = info.data
        if values.get("do_validation") and values.get("model_type") == "i2v" and not v:
            raise ValueError("validation_images must be specified when do_validation is True and model_type is i2v")
        return v

    @field_validator("validation_videos")
    @classmethod
    def validate_validation_videos(cls, v: str | None, info: ValidationInfo) -> str | None:
        """Require ``validation_videos`` when validating a ``v2v`` model.

        NOTE(review): ``model_type`` is declared as Literal["i2v", "t2v", "i2vFlow"],
        so the ``"v2v"`` comparison below can never be true for a validated model.
        The check is kept as-is; confirm whether another model type was intended.
        """
        values = info.data
        if values.get("do_validation") and values.get("model_type") == "v2v" and not v:
            raise ValueError("validation_videos must be specified when do_validation is True and model_type is v2v")
        return v

    @field_validator("validation_steps")
    @classmethod
    def validate_validation_steps(cls, v: int | None, info: ValidationInfo) -> int | None:
        """Check ``validation_steps`` against ``do_validation`` and ``checkpointing_steps``.

        Raises:
            ValueError: if validation is enabled and ``validation_steps`` is None, or
                is not a multiple of a non-zero ``checkpointing_steps``.
        """
        values = info.data
        if not values.get("do_validation"):
            return v

        if v is None:
            raise ValueError("validation_steps must be specified when do_validation is True")
        # A missing or zero checkpointing_steps skips the divisibility check.
        if values.get("checkpointing_steps") and v % values["checkpointing_steps"] != 0:
            raise ValueError("validation_steps must be a multiple of checkpointing_steps")
        return v

    @field_validator("train_resolution")
    @classmethod
    def validate_train_resolution(cls, v: Tuple[int, int, int], info: ValidationInfo) -> Tuple[int, int, int]:
        """Validate the ``(frames, height, width)`` training resolution.

        Raises:
            ValueError: if the value cannot be unpacked into three items, if
                ``frames - 1`` is not a multiple of 8, or if ``model_name`` is one of
                the cogvideox-5b names and the spatial size is not 480x720.
        """
        # Only the unpack is guarded, so the constraint errors below keep their own messages.
        try:
            frames, height, width = v
        except (TypeError, ValueError):
            raise ValueError("train_resolution must be in format 'frames x height x width'") from None

        if (frames - 1) % 8 != 0:
            raise ValueError("Number of frames - 1 must be a multiple of 8")

        model_name = info.data.get("model_name", "")
        if model_name in ("cogvideox-5b-i2v", "cogvideox-5b-t2v") and (height, width) != (480, 720):
            raise ValueError("For cogvideox-5b models, height must be 480 and width must be 720")

        return v

    @field_validator("mixed_precision")
    @classmethod
    def validate_mixed_precision(cls, v: str, info: ValidationInfo) -> str:
        """Warn when fp16 is requested and ``model_path`` does not contain ``cogvideox-2b``.

        The value is returned unchanged; this validator never raises.
        """
        if v == "fp16" and "cogvideox-2b" not in str(info.data.get("model_path", "")).lower():
            logging.warning(
                "All CogVideoX models except cogvideox-2b were trained with bfloat16. "
                "Using fp16 precision may lead to training instability."
            )
        return v

    @classmethod
    def parse_args(cls):
        """Parse command line arguments and return Args instance.

        Raises:
            ValueError: if ``--train_resolution`` is not three ``x``-separated
                integers, or if ``--additional_save_blocks`` is given with a
                ``--training_type`` other than ``lora``.
        """

        def str_to_bool(value: str) -> bool:
            # argparse's `type=bool` calls bool() on the raw string, so any
            # non-empty text (including "False") becomes True. Parse explicitly.
            lowered = value.strip().lower()
            if lowered in ("true", "t", "yes", "y", "1", "on"):
                return True
            if lowered in ("false", "f", "no", "n", "0", "off", ""):
                return False
            raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")

        parser = argparse.ArgumentParser()

        # Required arguments
        parser.add_argument("--model_path", type=str, required=True)
        parser.add_argument("--model_name", type=str, required=True)
        parser.add_argument("--model_type", type=str, required=True)
        parser.add_argument("--depth_ckpt_path", type=str, required=False, default="./ckpt/others/depth_anything_v2_metric_hypersim_vitb.pth", help="Path to the checkpoint of the depth estimation networks")
        parser.add_argument("--training_type", type=str, required=True)
        parser.add_argument("--additional_save_blocks", type=str, required=False, default=None)
        parser.add_argument("--output_dir", type=str, required=True)
        parser.add_argument("--data_root", type=str, required=True)
        parser.add_argument("--caption_column", type=str, required=True)
        parser.add_argument("--video_column", type=str, required=True)
        parser.add_argument("--train_resolution", type=str, required=True)
        parser.add_argument("--report_to", type=str, required=True)
        parser.add_argument("--run_name", type=str, required=False, default='CogVideoX')

        # Training hyperparameters
        parser.add_argument("--seed", type=int, default=42)
        parser.add_argument("--train_epochs", type=int, default=10)
        parser.add_argument("--train_steps", type=int, default=None)
        parser.add_argument("--gradient_accumulation_steps", type=int, default=1)
        parser.add_argument("--batch_size", type=int, default=1)
        parser.add_argument("--learning_rate", type=float, default=2e-5)
        parser.add_argument("--optimizer", type=str, default="adamw")
        parser.add_argument("--beta1", type=float, default=0.9)
        parser.add_argument("--beta2", type=float, default=0.95)
        parser.add_argument("--beta3", type=float, default=0.98)
        parser.add_argument("--epsilon", type=float, default=1e-8)
        parser.add_argument("--weight_decay", type=float, default=1e-4)
        parser.add_argument("--max_grad_norm", type=float, default=1.0)

        # Learning-rate scheduler
        parser.add_argument("--lr_scheduler", type=str, default="constant_with_warmup")
        parser.add_argument("--lr_warmup_steps", type=int, default=100)
        parser.add_argument("--lr_num_cycles", type=int, default=1)
        parser.add_argument("--lr_power", type=float, default=1.0)

        # Data loading
        parser.add_argument("--num_workers", type=int, default=8)
        parser.add_argument("--pin_memory", type=str_to_bool, default=True)
        parser.add_argument("--image_column", type=str, default=None)

        # Precision / memory / runtime
        parser.add_argument("--mixed_precision", type=str, default="no")
        parser.add_argument("--gradient_checkpointing", type=str_to_bool, default=True)
        parser.add_argument("--enable_slicing", type=str_to_bool, default=True)
        parser.add_argument("--enable_tiling", type=str_to_bool, default=True)
        parser.add_argument("--nccl_timeout", type=int, default=1800)

        # LoRA
        parser.add_argument("--rank", type=int, default=128)
        parser.add_argument("--lora_alpha", type=int, default=64)
        parser.add_argument("--target_modules", type=str, nargs="+", default=["to_q", "to_k", "to_v", "to_out.0"])

        # Checkpointing
        parser.add_argument("--checkpointing_steps", type=int, default=200)
        parser.add_argument("--checkpointing_limit", type=int, default=10)
        parser.add_argument("--resume_from_checkpoint", type=str, default=None)

        # Validation
        # Kept exactly as before: only the text "true" (any case) enables validation.
        parser.add_argument("--do_validation", type=lambda x: x.lower() == 'true', default=False)
        parser.add_argument("--validation_steps", type=int, default=None)
        parser.add_argument("--validation_dir", type=str, default=None)
        parser.add_argument("--validation_prompts", type=str, default=None)
        parser.add_argument("--validation_images", type=str, default=None)
        parser.add_argument("--validation_videos", type=str, default=None)
        parser.add_argument("--gen_fps", type=int, default=15)
        parser.add_argument("--max_scene", type=int, default=8)

        # ControlNet / time sampling
        parser.add_argument("--controlnet_transformer_num_layers", type=int, default=8)
        parser.add_argument("--controlnet_input_channels", type=int, default=16)
        parser.add_argument("--controlnet_weights", type=float, default=1.0)
        parser.add_argument("--controlnet_guidance_start", type=float, default=0.0)
        parser.add_argument("--controlnet_guidance_end", type=float, default=1.0)
        parser.add_argument("--controlnet_out_proj_dim_factor", type=int, default=64)
        parser.add_argument("--controlnet_out_proj_zero_init", type=str_to_bool, default=True)
        parser.add_argument("--enable_time_sampling", type=str_to_bool, default=True)

        parser.add_argument("--time_sampling_type", type=str, default='truncated_normal')
        parser.add_argument("--time_sampling_mean", type=float, default=0.95)
        parser.add_argument("--time_sampling_std", type=float, default=0.1)
        parser.add_argument("--use_valid_mask", type=str_to_bool, default=False)
        parser.add_argument("--notextinflow", type=str_to_bool, default=False)

        args = parser.parse_args()

        # Convert "<frames>x<height>x<width>" into a tuple of ints. Both a wrong
        # number of parts and a non-integer part surface as ValueError here.
        try:
            frames, height, width = (int(part) for part in args.train_resolution.split("x"))
        except ValueError as err:
            raise ValueError("train_resolution must be in format 'frames x height x width'") from err
        args.train_resolution = (frames, height, width)

        if args.additional_save_blocks is not None:
            # Explicit raise instead of `assert`, which is stripped under `python -O`.
            if args.training_type != 'lora':
                raise ValueError("additional_save_blocks is only supported when training_type is 'lora'")
            args.additional_save_blocks = args.additional_save_blocks.split(',')

        return cls(**vars(args))
|
|