# blur2vid/training/utils.py
import argparse
import os
import shutil
import tempfile
from typing import List, Optional, Tuple, Union

import cv2
import numpy as np
import torch
import yaml
from accelerate.logging import get_logger
from diffusers.models.embeddings import get_3d_rotary_pos_embed
from diffusers.pipelines.cogvideo.pipeline_cogvideox import get_resize_crop_region_for_grid
from transformers import T5EncoderModel, T5Tokenizer

logger = get_logger(__name__)

def get_args():
    parser = argparse.ArgumentParser(description="Training script for CogVideoX using config file.")
    parser.add_argument(
        "--config",
        type=str,
        required=True,
        help="Path to the YAML config file.",
    )
    args = parser.parse_args()
    with open(args.config, "r") as f:
        config = yaml.safe_load(f)
    # Expose the top-level config keys as attributes for easier downstream usage.
    # Note: nested dicts remain plain dicts on the resulting namespace.
    args = argparse.Namespace(**config)
    return args

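
# Example usage (illustrative sketch; the config keys shown are assumptions
# inferred from how get_optimizer below reads them, not a documented schema):
#
#   $ python train.py --config configs/train.yaml
#
#   # configs/train.yaml
#   optimizer: adamw
#   learning_rate: 1.0e-4
#   adam_beta1: 0.9
#   adam_beta2: 0.95
#   adam_epsilon: 1.0e-8
#   adam_weight_decay: 1.0e-4
#   use_8bit_adam: false
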
def atomic_save(save_path, accelerator):
    """Save accelerator state to `save_path` without leaving a partial checkpoint.

    The state is first written to a temporary sibling directory, the existing
    checkpoint (if any) is moved aside as a backup, and the new directory is
    renamed into place. On failure, the backup is restored.
    """
    parent = os.path.dirname(save_path)
    tmp_dir = tempfile.mkdtemp(dir=parent)
    backup_dir = save_path + "_backup"
    try:
        # Save state into the temp directory
        accelerator.save_state(tmp_dir)
        # Back up the existing save_path if it exists
        if os.path.exists(save_path):
            os.rename(save_path, backup_dir)
        # Atomically move the temp directory into place
        os.rename(tmp_dir, save_path)
        # Clean up the backup directory
        if os.path.exists(backup_dir):
            shutil.rmtree(backup_dir)
    except Exception:
        # Clean up the temp directory on failure
        if os.path.exists(tmp_dir):
            shutil.rmtree(tmp_dir)
        # Restore from backup if the replacement failed
        if os.path.exists(backup_dir):
            if os.path.exists(save_path):
                shutil.rmtree(save_path)
            os.rename(backup_dir, save_path)
        raise

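
# Example usage (illustrative; `output_dir` and the checkpoint name are hypothetical):
#
#   from accelerate import Accelerator
#   accelerator = Accelerator()
#   ...
#   atomic_save(os.path.join(output_dir, "checkpoint-100"), accelerator)
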
def get_optimizer(args, params_to_optimize, use_deepspeed: bool = False):
    # Use the DeepSpeed optimizer
    if use_deepspeed:
        from accelerate.utils import DummyOptim

        return DummyOptim(
            params_to_optimize,
            lr=args.learning_rate,
            betas=(args.adam_beta1, args.adam_beta2),
            eps=args.adam_epsilon,
            weight_decay=args.adam_weight_decay,
        )

    # Optimizer creation
    supported_optimizers = ["adam", "adamw", "prodigy"]
    if args.optimizer not in supported_optimizers:
        logger.warning(
            f"Unsupported choice of optimizer: {args.optimizer}. Supported optimizers include {supported_optimizers}. Defaulting to AdamW"
        )
        args.optimizer = "adamw"

    if args.use_8bit_adam and args.optimizer.lower() not in ["adam", "adamw"]:
        logger.warning(
            f"use_8bit_adam is ignored when optimizer is not set to 'Adam' or 'AdamW'. Optimizer was "
            f"set to {args.optimizer.lower()}"
        )

    if args.use_8bit_adam:
        try:
            import bitsandbytes as bnb
        except ImportError:
            raise ImportError(
                "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
            )

    if args.optimizer.lower() == "adamw":
        optimizer_class = bnb.optim.AdamW8bit if args.use_8bit_adam else torch.optim.AdamW
        optimizer = optimizer_class(
            params_to_optimize,
            betas=(args.adam_beta1, args.adam_beta2),
            eps=args.adam_epsilon,
            weight_decay=args.adam_weight_decay,
        )
    elif args.optimizer.lower() == "adam":
        optimizer_class = bnb.optim.Adam8bit if args.use_8bit_adam else torch.optim.Adam
        optimizer = optimizer_class(
            params_to_optimize,
            betas=(args.adam_beta1, args.adam_beta2),
            eps=args.adam_epsilon,
            weight_decay=args.adam_weight_decay,
        )
    elif args.optimizer.lower() == "prodigy":
        try:
            import prodigyopt
        except ImportError:
            raise ImportError("To use Prodigy, please install the prodigyopt library: `pip install prodigyopt`")

        optimizer_class = prodigyopt.Prodigy

        if args.learning_rate <= 0.1:
            logger.warning(
                "Learning rate is too low. When using Prodigy, it's generally better to set the learning rate around 1.0"
            )

        optimizer = optimizer_class(
            params_to_optimize,
            lr=args.learning_rate,
            betas=(args.adam_beta1, args.adam_beta2),
            beta3=args.prodigy_beta3,
            weight_decay=args.adam_weight_decay,
            eps=args.adam_epsilon,
            decouple=args.prodigy_decouple,
            use_bias_correction=args.prodigy_use_bias_correction,
            safeguard_warmup=args.prodigy_safeguard_warmup,
        )

    return optimizer

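
# Example usage (illustrative; assumes `args` came from get_args() with the
# optimizer keys sketched above, and `transformer` is the model being trained.
# Per-group "lr" entries are an assumption based on the AdamW branches not
# passing `lr` to the optimizer constructor):
#
#   params_to_optimize = [{"params": transformer.parameters(), "lr": args.learning_rate}]
#   optimizer = get_optimizer(args, params_to_optimize)
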
def prepare_rotary_positional_embeddings(
    height: int,
    width: int,
    num_frames: int,
    vae_scale_factor_spatial: int = 8,
    patch_size: int = 2,
    attention_head_dim: int = 64,
    device: Optional[torch.device] = None,
    base_height: int = 480,
    base_width: int = 720,
) -> Tuple[torch.Tensor, torch.Tensor]:
    grid_height = height // (vae_scale_factor_spatial * patch_size)
    grid_width = width // (vae_scale_factor_spatial * patch_size)
    base_size_width = base_width // (vae_scale_factor_spatial * patch_size)
    base_size_height = base_height // (vae_scale_factor_spatial * patch_size)

    grid_crops_coords = get_resize_crop_region_for_grid((grid_height, grid_width), base_size_width, base_size_height)
    freqs_cos, freqs_sin = get_3d_rotary_pos_embed(
        embed_dim=attention_head_dim,
        crops_coords=grid_crops_coords,
        grid_size=(grid_height, grid_width),
        temporal_size=num_frames,
    )

    freqs_cos = freqs_cos.to(device=device)
    freqs_sin = freqs_sin.to(device=device)
    return freqs_cos, freqs_sin

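
# Example usage (illustrative; `num_frames` counts latent frames. Passing 13
# for a 49-frame video assumes CogVideoX's 4x temporal VAE compression, which
# is not enforced by this function):
#
#   freqs_cos, freqs_sin = prepare_rotary_positional_embeddings(
#       height=480, width=720, num_frames=13, device=torch.device("cuda")
#   )
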
def _get_t5_prompt_embeds(
    tokenizer: T5Tokenizer,
    text_encoder: T5EncoderModel,
    prompt: Union[str, List[str]],
    num_videos_per_prompt: int = 1,
    max_sequence_length: int = 226,
    device: Optional[torch.device] = None,
    dtype: Optional[torch.dtype] = None,
    text_input_ids=None,
):
    prompt = [prompt] if isinstance(prompt, str) else prompt
    batch_size = len(prompt)

    if tokenizer is not None:
        text_inputs = tokenizer(
            prompt,
            padding="max_length",
            max_length=max_sequence_length,
            truncation=True,
            add_special_tokens=True,
            return_tensors="pt",
        )
        text_input_ids = text_inputs.input_ids
    else:
        if text_input_ids is None:
            raise ValueError("`text_input_ids` must be provided when the tokenizer is not specified.")

    prompt_embeds = text_encoder(text_input_ids.to(device))[0]
    prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)

    # duplicate text embeddings for each generation per prompt, using mps friendly method
    _, seq_len, _ = prompt_embeds.shape
    prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
    prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)

    return prompt_embeds

def encode_prompt(
    tokenizer: T5Tokenizer,
    text_encoder: T5EncoderModel,
    prompt: Union[str, List[str]],
    num_videos_per_prompt: int = 1,
    max_sequence_length: int = 226,
    device: Optional[torch.device] = None,
    dtype: Optional[torch.dtype] = None,
    text_input_ids=None,
):
    prompt = [prompt] if isinstance(prompt, str) else prompt
    prompt_embeds = _get_t5_prompt_embeds(
        tokenizer,
        text_encoder,
        prompt=prompt,
        num_videos_per_prompt=num_videos_per_prompt,
        max_sequence_length=max_sequence_length,
        device=device,
        dtype=dtype,
        text_input_ids=text_input_ids,
    )
    return prompt_embeds

def compute_prompt_embeddings(
    tokenizer, text_encoder, prompt, max_sequence_length, device, dtype, requires_grad: bool = False
):
    # torch.set_grad_enabled collapses the otherwise duplicated grad/no-grad branches.
    with torch.set_grad_enabled(requires_grad):
        prompt_embeds = encode_prompt(
            tokenizer,
            text_encoder,
            prompt,
            num_videos_per_prompt=1,
            max_sequence_length=max_sequence_length,
            device=device,
            dtype=dtype,
        )
    return prompt_embeds

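
# Example usage (illustrative sketch; the model id and dtype are assumptions,
# not pinned by this file):
#
#   tokenizer = T5Tokenizer.from_pretrained("THUDM/CogVideoX-2b", subfolder="tokenizer")
#   text_encoder = T5EncoderModel.from_pretrained("THUDM/CogVideoX-2b", subfolder="text_encoder")
#   prompt_embeds = compute_prompt_embeddings(
#       tokenizer, text_encoder, ["a video of a cat"],
#       max_sequence_length=226, device=torch.device("cuda"), dtype=torch.bfloat16,
#   )
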
def save_frames_as_pngs(
    video_array,
    output_dir,
    downsample_spatial=1,   # e.g. 2 to halve width & height
    downsample_temporal=1,  # e.g. 2 to keep every 2nd frame
):
    """
    Save each frame of a (T, H, W, C) uint8 numpy array as a PNG with no compression.
    """
    assert video_array.ndim == 4 and video_array.shape[-1] == 3, "Expected (T, H, W, C=3) array"
    assert video_array.dtype == np.uint8, "Expected uint8 array"
    os.makedirs(output_dir, exist_ok=True)

    # Temporal downsample: keep every `downsample_temporal`-th frame
    frames = video_array[::downsample_temporal]

    # Compute the spatially downsampled size
    _, H, W, _ = frames.shape
    new_size = (W // downsample_spatial, H // downsample_spatial)

    # PNG compression param: 0 = no compression
    png_params = [cv2.IMWRITE_PNG_COMPRESSION, 0]

    for idx, frame in enumerate(frames):
        # Frames are RGB; convert to BGR (contiguous) for OpenCV
        bgr = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
        if downsample_spatial > 1:
            bgr = cv2.resize(bgr, new_size, interpolation=cv2.INTER_NEAREST)
        filename = os.path.join(output_dir, "frame_{:05d}.png".format(idx))
        if not cv2.imwrite(filename, bgr, png_params):
            raise RuntimeError(f"Failed to write frame {filename}")

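
# Example usage (illustrative; the video shape and output directory are assumptions):
#
#   video = np.random.randint(0, 256, size=(49, 480, 720, 3), dtype=np.uint8)
#   save_frames_as_pngs(video, "debug_frames", downsample_spatial=2, downsample_temporal=2)
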