# Module header: dependencies for the training script plus one module-level flag.
import argparse
import datetime
import logging
import inspect
import math
import os
import random
import gc
import copy

from typing import Dict, Optional, Tuple
from omegaconf import OmegaConf

import cv2
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
import torchvision.transforms as T
import diffusers
import transformers

from torchvision import transforms
from tqdm.auto import tqdm

from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import set_seed

from models.unet_3d_condition import UNet3DConditionModel

from diffusers.models import AutoencoderKL
from diffusers import DPMSolverMultistepScheduler, DDPMScheduler, TextToVideoSDPipeline
from diffusers.optimization import get_scheduler
# NOTE: `export_to_video` imported here is shadowed by a local definition further down this file.
from diffusers.utils import check_min_version, export_to_video
from diffusers.utils.import_utils import is_xformers_available
from diffusers.models.attention_processor import AttnProcessor2_0, Attention
from diffusers.models.attention import BasicTransformerBlock

from transformers import CLIPTextModel, CLIPTokenizer
from transformers.models.clip.modeling_clip import CLIPEncoder
from utils.dataset import VideoJsonDataset, SingleVideoDataset, \
    ImageDataset, VideoFolderDataset, CachedDataset
from einops import rearrange, repeat

from utils.lora import (
    extract_lora_ups_down,
    inject_trainable_lora,
    inject_trainable_lora_extended,
    save_lora_weight,
    train_patch_pipe,
    monkeypatch_or_replace_lora,
    monkeypatch_or_replace_lora_extended
)

# Global flag consulted by handle_trainable_modules: it is set to True after the
# "params have been unfrozen" message is printed, so that message appears at most once.
already_printed_trainables = False

# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.10.0.dev0")

logger = get_logger(__name__, log_level="INFO")


def create_logging(logging, logger, accelerator):
    """Configure stdlib logging and log ``accelerator.state`` on every process.

    Args:
        logging: The stdlib ``logging`` module (passed in explicitly).
        logger: An accelerate logger (supports ``main_process_only``).
        accelerator: The active ``Accelerator``.
    """
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )
    logger.info(accelerator.state, main_process_only=False)


def accelerate_set_verbose(accelerator):
    """Keep transformers/diffusers logging verbose on the local main process, quiet elsewhere."""
    if accelerator.is_local_main_process:
        transformers.utils.logging.set_verbosity_warning()
        diffusers.utils.logging.set_verbosity_info()
    else:
        transformers.utils.logging.set_verbosity_error()
        diffusers.utils.logging.set_verbosity_error()


def get_train_dataset(dataset_types, train_data, tokenizer):
    """Instantiate every dataset class whose ``__getname__()`` appears in ``dataset_types``.

    Args:
        dataset_types: Iterable of dataset names. A bare string is treated as a
            single name.
        train_data: Keyword arguments forwarded to each dataset constructor.
        tokenizer: Tokenizer forwarded to each dataset constructor.

    Returns:
        A non-empty list of dataset instances.

    Raises:
        ValueError: If no requested name matches a known dataset class.
    """
    # A bare string would otherwise be iterated character by character and
    # never match any dataset name.
    if isinstance(dataset_types, str):
        dataset_types = (dataset_types,)

    train_datasets = []

    # Loop through all available datasets, get the name, then add to list of data to process.
    for DataSet in [VideoJsonDataset, SingleVideoDataset, ImageDataset, VideoFolderDataset]:
        for dataset in dataset_types:
            if dataset == DataSet.__getname__():
                train_datasets.append(DataSet(**train_data, tokenizer=tokenizer))

    if not train_datasets:
        raise ValueError("Dataset type not found: 'json', 'single_video', 'folder', 'image'")
    return train_datasets


def extend_datasets(datasets, dataset_items, extend=False):
    """Repeat item lists of smaller datasets so they reach the largest dataset's length.

    For each non-empty dataset shorter than the biggest one, every attribute named
    in ``dataset_items`` that the dataset has is repeated and truncated to the
    biggest length. Nothing is changed unless ``extend`` is True.

    NOTE(review): ``extended`` is shared across datasets, so each attribute name is
    extended on at most one dataset — confirm this is intended.
    """
    biggest_data_len = max(len(x) for x in datasets)
    extended = []

    for dataset in datasets:
        if len(dataset) == 0:
            continue

        if len(dataset) < biggest_data_len:
            for item in dataset_items:
                if extend and item not in extended and hasattr(dataset, item):
                    print(f"Extending {item}")

                    value = getattr(dataset, item)
                    value *= biggest_data_len
                    value = value[:biggest_data_len]

                    setattr(dataset, item, value)

                    print(f"New {item} dataset length: {dataset.__len__()}")
                    extended.append(item)


def export_to_video(video_frames, output_video_path, fps):
    """Write RGB frames to ``output_video_path`` as an ``mp4v`` video.

    Args:
        video_frames: Non-empty sequence of ``(H, W, 3)`` RGB frames (presumably
            uint8, as OpenCV's writer expects — TODO confirm for all callers).
        output_video_path: Destination file path.
        fps: Frames per second of the written video.
    """
    fourcc = cv2.VideoWriter_fourcc(*"mp4v")
    h, w, _ = video_frames[0].shape
    video_writer = cv2.VideoWriter(output_video_path, fourcc, fps=fps, frameSize=(w, h))
    try:
        for frame in video_frames:
            video_writer.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
    finally:
        # Without release() the container is never finalized and the file may be unreadable.
        video_writer.release()


def create_output_folders(output_dir, config):
    """Create a timestamped ``train_*`` run folder (with ``samples/``) and save ``config`` there.

    Returns:
        The path of the new run folder.
    """
    now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
    out_dir = os.path.join(output_dir, f"train_{now}")

    os.makedirs(out_dir, exist_ok=True)
    os.makedirs(f"{out_dir}/samples", exist_ok=True)
    OmegaConf.save(config, os.path.join(out_dir, 'config.yaml'))

    return out_dir


def load_primary_models(pretrained_model_path):
    """Load scheduler, tokenizer, text encoder, VAE and 3D UNet from their subfolders.

    Returns:
        ``(noise_scheduler, tokenizer, text_encoder, vae, unet)``.
    """
    noise_scheduler = DDPMScheduler.from_pretrained(pretrained_model_path, subfolder="scheduler")
    tokenizer = CLIPTokenizer.from_pretrained(pretrained_model_path, subfolder="tokenizer")
    text_encoder = CLIPTextModel.from_pretrained(pretrained_model_path, subfolder="text_encoder")
    vae = AutoencoderKL.from_pretrained(pretrained_model_path, subfolder="vae")
    unet = UNet3DConditionModel.from_pretrained(pretrained_model_path, subfolder="unet")

    return noise_scheduler, tokenizer, text_encoder, vae, unet


def unet_and_text_g_c(unet, text_encoder, unet_enable, text_enable):
    """Toggle gradient checkpointing on the UNet and on the text encoder's ``CLIPEncoder``."""
    unet._set_gradient_checkpointing(value=unet_enable)
    text_encoder._set_gradient_checkpointing(CLIPEncoder, value=text_enable)


def freeze_models(models_to_freeze):
    """Disable gradients on every non-None model in ``models_to_freeze``."""
    for model in models_to_freeze:
        if model is not None:
            model.requires_grad_(False)


def is_attn(name):
    """Return True if the last dotted component of ``name`` is ``attn1`` or ``attn2``."""
    return name.split('.')[-1] in ('attn1', 'attn2')


def set_processors(attentions):
    """Install ``AttnProcessor2_0`` on each attention module, skipping missing (None) ones."""
    for attn in attentions:
        # BasicTransformerBlock.attn2 can be None (e.g. without cross-attention).
        if attn is not None:
            attn.set_processor(AttnProcessor2_0())


def set_torch_2_attn(unet):
    """Switch every ``BasicTransformerBlock`` in ``unet`` to PyTorch 2 SDPA attention."""
    optim_count = 0

    for module in unet.modules():
        if isinstance(module, BasicTransformerBlock):
            set_processors([module.attn1, module.attn2])
            optim_count += 1

    if optim_count > 0:
        print(f"{optim_count} Attention layers using Scaled Dot Product Attention.")


def handle_memory_attention(enable_xformers_memory_efficient_attention, enable_torch_2_attn, unet):
    """Best-effort switch of ``unet`` to a memory-efficient attention implementation.

    xformers is only tried when PyTorch lacks ``scaled_dot_product_attention``;
    SDPA processors are installed only when ``enable_torch_2_attn`` is set and
    available. Failures are reported and otherwise ignored.
    """
    try:
        is_torch_2 = hasattr(F, 'scaled_dot_product_attention')

        if enable_xformers_memory_efficient_attention and not is_torch_2:
            if is_xformers_available():
                from xformers.ops import MemoryEfficientAttentionFlashAttentionOp
                unet.enable_xformers_memory_efficient_attention(attention_op=MemoryEfficientAttentionFlashAttentionOp)
            else:
                raise ValueError("xformers is not available. Make sure it is installed correctly")

        if enable_torch_2_attn and is_torch_2:
            set_torch_2_attn(unet)
    except Exception as e:
        # Deliberately best-effort: training continues with the default attention processors.
        print(e)
        print("Could not enable memory efficient attention for xformers or Torch 2.0.")


def inject_lora(use_lora, model, replace_modules, is_extended=False, dropout=0.0, lora_path='', r=16):
    """Optionally load existing LoRA weights from ``lora_path`` and/or inject trainable LoRA layers.

    Args:
        use_lora: Inject new trainable LoRA layers when True.
        model: ``CLIPTextModel`` or ``UNet3DConditionModel`` to patch.
        replace_modules: Module class names targeted by the injector.
        is_extended: Use ``inject_trainable_lora_extended`` (no dropout argument).
        dropout: LoRA dropout for the non-extended injector.
        lora_path: Directory scanned for ``*.pt`` files named with
            ``text_encoder`` / ``unet``.
        r: LoRA rank.

    Returns:
        ``(params, negation)`` from the injector, or ``(None, None)`` when
        ``use_lora`` is False.
    """
    injector = (
        inject_trainable_lora if not is_extended
        else
        inject_trainable_lora_extended
    )

    params = None
    negation = None

    if os.path.exists(lora_path):
        try:
            for f in os.listdir(lora_path):
                if f.endswith('.pt'):
                    lora_file = os.path.join(lora_path, f)

                    # NOTE: torch.load unpickles the file; only load LoRA files from trusted sources.
                    if 'text_encoder' in f and isinstance(model, CLIPTextModel):
                        monkeypatch_or_replace_lora(
                            model,
                            torch.load(lora_file),
                            target_replace_module=replace_modules,
                            r=r
                        )
                        print("Successfully loaded Text Encoder LoRa.")

                    if 'unet' in f and isinstance(model, UNet3DConditionModel):
                        monkeypatch_or_replace_lora_extended(
                            model,
                            torch.load(lora_file),
                            target_replace_module=replace_modules,
                            r=r
                        )
                        print("Successfully loaded UNET LoRa.")

        except Exception as e:
            print(e)
            print("Could not load LoRAs. Injecting new ones instead...")

    if use_lora:
        injector_args = {
            "model": model,
            "target_replace_module": replace_modules,
            "r": r
        }
        if not is_extended:
            injector_args['dropout_p'] = dropout

        params, negation = injector(**injector_args)

        # Confirm injection once, on the first complete up/down pair found.
        for _up, _down in extract_lora_ups_down(model, target_replace_module=replace_modules):
            if _up is not None and _down is not None:
                print(f"Lora successfully injected into {model.__class__.__name__}.")
                break

    return params, negation


def save_lora(model, name, condition, replace_modules, step, save_path):
    """Save ``model``'s LoRA weights to ``<save_path>/<step>_<name>.pt`` when enabled."""
    if condition and replace_modules is not None:
        save_path = f"{save_path}/{step}_{name}.pt"
        save_lora_weight(model, save_path, replace_modules)


def handle_lora_save(
        use_unet_lora,
        use_text_lora,
        model,
        save_path,
        checkpoint_step,
        unet_target_modules,
        text_encoder_target_modules
    ):
    """Write UNet / text-encoder LoRA weights under ``<save_path>/lora`` and patch the pipeline."""
    save_path = f"{save_path}/lora"
    os.makedirs(save_path, exist_ok=True)

    save_lora(
        model.unet,
        'unet',
        use_unet_lora,
        unet_target_modules,
        checkpoint_step,
        save_path,
    )
    save_lora(
        model.text_encoder,
        'text_encoder',
        use_text_lora,
        text_encoder_target_modules,
        checkpoint_step,
        save_path
    )

    train_patch_pipe(model, use_unet_lora, use_text_lora)


def param_optim(model, condition, extra_params=None, is_lora=False, negation=None):
    """Bundle a model (or LoRA param iterables) with the flag that enables its training."""
    return {
        "model": model,
        "condition": condition,
        'extra_params': extra_params,
        'is_lora': is_lora,
        "negation": negation
    }


def create_optim_params(name='param', params=None, lr=5e-6, extra_params=None):
    """Build one optimizer param-group dict; ``extra_params`` entries override the defaults."""
    group = {
        "name": name,
        "params": params,
        "lr": lr
    }
    if extra_params is not None:
        group.update(extra_params)
    return group


def negate_params(name, negation):
    """Return True if ``name`` contains any ``negation`` entry and not ``'temp'``.

    We have to do this if we are co-training with LoRA.
    This ensures that parameter groups aren't duplicated.
    """
    if negation is None:
        return False
    return any(n in name and 'temp' not in name for n in negation)


def create_optimizer_params(model_list, lr):
    """Turn ``param_optim`` entries into optimizer param groups.

    Enabled LoRA entries become a single group built from their chained
    parameter iterables; enabled regular models contribute one group per
    named parameter, skipping any whose name contains ``'lora'``.
    """
    import itertools

    optimizer_params = []

    for optim in model_list:
        model = optim["model"]
        condition = optim["condition"]
        extra_params = optim["extra_params"]
        is_lora = optim["is_lora"]

        if not condition:
            continue

        # LoRA: `model` holds the injector's iterables of parameters.
        if is_lora:
            optimizer_params.append(
                create_optim_params(
                    params=itertools.chain(*model),
                    extra_params=extra_params
                )
            )
            continue

        for n, p in model.named_parameters():
            # LoRA weights are optimized through their own group.
            if 'lora' in n:
                continue
            optimizer_params.append(create_optim_params(n, p, lr, extra_params))

    return optimizer_params


def get_optimizer(use_8bit_adam):
    """Return the optimizer class: bitsandbytes ``AdamW8bit`` or ``torch.optim.AdamW``."""
    if not use_8bit_adam:
        return torch.optim.AdamW

    try:
        import bitsandbytes as bnb
    except ImportError as err:
        raise ImportError(
            "Please install bitsandbytes to use 8-bit Adam. You can do so by running `pip install bitsandbytes`"
        ) from err

    return bnb.optim.AdamW8bit


def is_mixed_precision(accelerator):
    """Map ``accelerator.mixed_precision`` to a torch dtype (float32 when not fp16/bf16)."""
    if accelerator.mixed_precision == "fp16":
        return torch.float16
    if accelerator.mixed_precision == "bf16":
        return torch.bfloat16
    return torch.float32


def cast_to_gpu_and_type(model_list, accelerator, weight_dtype):
    """Move every non-None model to ``accelerator.device`` with ``weight_dtype``."""
    for model in model_list:
        if model is not None:
            model.to(accelerator.device, dtype=weight_dtype)


def handle_cache_latents(
        should_cache,
        output_dir,
        train_dataloader,
        train_batch_size,
        vae,
        cached_latent_dir=None
    ):
    """Optionally pre-encode the training data into VAE latents cached on disk.

    Caching avoids running the VAE encoder inside the training loop. Unless
    ``cached_latent_dir`` is given, each sample from ``train_dataloader`` is
    encoded and written to ``<output_dir>/cached_latents/cached_<n>.pt``
    (one file per sample).

    Returns:
        ``None`` when ``should_cache`` is False, otherwise a shuffled
        ``DataLoader`` over ``CachedDataset``.
    """
    if not should_cache:
        return None

    vae.to('cuda', dtype=torch.float16)
    vae.enable_slicing()

    if cached_latent_dir is not None:
        cache_save_dir = os.path.abspath(cached_latent_dir)
    else:
        cache_save_dir = f"{output_dir}/cached_latents"
        os.makedirs(cache_save_dir, exist_ok=True)

        sample_idx = 0
        for batch in tqdm(train_dataloader, desc="Caching Latents."):
            pixel_values = batch['pixel_values'].to('cuda', dtype=torch.float16)
            with torch.no_grad():
                batch['pixel_values'] = tensor_to_vae_latent(pixel_values, vae)

            # Save every sample of the batch, not only the first one. Tensors are
            # cloned so each file stores its own slice, not the whole batch storage.
            for j in range(len(batch['pixel_values'])):
                sample = {
                    k: v[j].clone() if torch.is_tensor(v) else v[j]
                    for k, v in batch.items()
                }
                torch.save(sample, f"{cache_save_dir}/cached_{sample_idx}.pt")
                sample_idx += 1

            del pixel_values
            del batch

            # We do this to avoid fragmentation from casting latents between devices.
            torch.cuda.empty_cache()

    return torch.utils.data.DataLoader(
        CachedDataset(cache_dir=cache_save_dir),
        batch_size=train_batch_size,
        shuffle=True,
        num_workers=0
    )


def handle_trainable_modules(model, trainable_modules=None, is_enabled=True, negation=None):
    """Set ``requires_grad`` on parameters of sub-modules selected by name.

    Args:
        model: Module whose parameters are toggled.
        trainable_modules: Substrings matched against ``named_modules()`` names;
            the entry ``'all'`` toggles the whole model. A bare string is a
            single entry. ``None`` leaves the model untouched.
        is_enabled: Value written to ``requires_grad``.
        negation: Unused; kept for call-site compatibility.

    Modules whose own name contains ``'lora'`` are skipped. Prints the number of
    unfrozen parameters at most once per process (global flag).
    """
    global already_printed_trainables

    if trainable_modules is None:
        return

    # tuple("all") would be ('a', 'l', 'l') and match nearly every module name.
    if isinstance(trainable_modules, str):
        trainable_modules = (trainable_modules,)
    trainable_modules = tuple(trainable_modules)

    touched = set()
    if 'all' in trainable_modules:
        model.requires_grad_(is_enabled)
        touched = {id(p) for p in model.parameters()}
    else:
        for name, module in model.named_modules():
            if 'lora' in name or not any(tm in name for tm in trainable_modules):
                continue
            for param in module.parameters():
                param.requires_grad_(is_enabled)
                # A set avoids double counting nested or multiply-matched modules.
                touched.add(id(param))

    unfrozen_params = len(touched) if is_enabled else 0
    if unfrozen_params > 0 and not already_printed_trainables:
        already_printed_trainables = True
        print(f"{unfrozen_params} params have been unfrozen for training.")


def tensor_to_vae_latent(t, vae):
    """Encode a ``(b, f, c, h, w)`` video batch into ``(b, c, f, h, w)`` VAE latents.

    Latents are sampled from the VAE posterior and multiplied by the hard-coded
    0.18215 scale (presumably the VAE's latent scaling factor — confirm per model).
    """
    video_length = t.shape[1]

    t = rearrange(t, "b f c h w -> (b f) c h w")
    latents = vae.encode(t).latent_dist.sample()
    latents = rearrange(latents, "(b f) c h w -> b c f h w", f=video_length)
    latents = latents * 0.18215

    return latents


def sample_noise(latents, noise_strength, use_offset_noise):
    """Return Gaussian noise shaped like 5-D ``latents``, optionally with offset noise.

    Offset noise is a per-(batch, channel, frame) constant of shape
    ``(b, c, f, 1, 1)`` scaled by ``noise_strength``.
    """
    b, c, f, *_ = latents.shape
    noise_latents = torch.randn_like(latents, device=latents.device)

    if use_offset_noise:
        offset_noise = torch.randn(b, c, f, 1, 1, device=latents.device)
        noise_latents = noise_latents + noise_strength * offset_noise

    return noise_latents


def should_sample(global_step, validation_steps, validation_data):
    """True on step 1 and every ``validation_steps`` steps, if previews are enabled."""
    return (global_step % validation_steps == 0 or global_step == 1) \
        and validation_data.sample_preview


def save_pipe(
        path,
        global_step,
        accelerator,
        unet,
        text_encoder,
        vae,
        output_dir,
        use_unet_lora,
        use_text_lora,
        unet_target_replace_module=None,
        text_target_replace_module=None,
        is_checkpoint=False,
    ):
    """Save an fp16 ``TextToVideoSDPipeline`` (and LoRA weights) built from the training models.

    Checkpoints go to ``<output_dir>/checkpoint-<step>``; the final save goes to
    ``output_dir``. UNet and text encoder are deep-copied so the live training
    models keep their state.
    """
    if is_checkpoint:
        save_path = os.path.join(output_dir, f"checkpoint-{global_step}")
        os.makedirs(save_path, exist_ok=True)
    else:
        save_path = output_dir

    # Save the dtypes so we can continue training at the same precision.
    u_dtype, t_dtype, v_dtype = unet.dtype, text_encoder.dtype, vae.dtype

    # Copy the model without creating a reference to it. This allows keeping the state of our lora training if enabled.
    unet_out = copy.deepcopy(accelerator.unwrap_model(unet, keep_fp32_wrapper=False))
    text_encoder_out = copy.deepcopy(accelerator.unwrap_model(text_encoder, keep_fp32_wrapper=False))

    # `vae` is not copied, so this fp16 cast also applies to the training VAE.
    pipeline = TextToVideoSDPipeline.from_pretrained(
        path,
        unet=unet_out,
        text_encoder=text_encoder_out,
        vae=vae,
    ).to(torch_dtype=torch.float16)

    handle_lora_save(
        use_unet_lora,
        use_text_lora,
        pipeline,
        output_dir,
        global_step,
        unet_target_replace_module,
        text_target_replace_module
    )

    pipeline.save_pretrained(save_path)

    if is_checkpoint:
        # NOTE(review): unwrap_model(keep_fp32_wrapper=False) may strip accelerate's
        # mixed-precision forward wrapper from the live models; re-preparing presumably
        # restores it. Must run on every process (DDP construction is collective).
        unet, text_encoder = accelerator.prepare(unet, text_encoder)
        models_to_cast_back = [(unet, u_dtype), (text_encoder, t_dtype), (vae, v_dtype)]
        for model, dtype in models_to_cast_back:
            model.to(accelerator.device, dtype=dtype)

    logger.info(f"Saved model at {save_path} on step {global_step}")

    del pipeline
    del unet_out
    del text_encoder_out
    torch.cuda.empty_cache()
    gc.collect()


def replace_prompt(prompt, token, wlist):
    """Replace all occurrences of the first word of ``wlist`` found in ``prompt`` with ``token``."""
    for w in wlist:
        if w in prompt:
            return prompt.replace(w, token)
    return prompt


def _safe_filename(text, max_bytes=200):
    """Make ``text`` usable as one path component.

    Path separators become ``_`` and the result is truncated to ``max_bytes``
    UTF-8 bytes so it stays under common 255-byte filename limits.
    """
    for sep in ("/", "\\"):
        text = text.replace(sep, "_")
    return text.encode("utf-8")[:max_bytes].decode("utf-8", "ignore")


def main(
    pretrained_model_path: str,
    output_dir: str,
    train_data: Dict,
    validation_data: Dict,
    dataset_types: Tuple[str] = ('json',),
    validation_steps: int = 100,
    trainable_modules: Tuple[str] = ("attn1", "attn2"),
    trainable_text_modules: Tuple[str] = ("all",),
    extra_unet_params = None,
    extra_text_encoder_params = None,
    train_batch_size: int = 1,
    max_train_steps: int = 500,
    learning_rate: float = 5e-5,
    scale_lr: bool = False,
    lr_scheduler: str = "constant",
    lr_warmup_steps: int = 0,
    adam_beta1: float = 0.9,
    adam_beta2: float = 0.999,
    adam_weight_decay: float = 1e-2,
    adam_epsilon: float = 1e-08,
    max_grad_norm: float = 1.0,
    gradient_accumulation_steps: int = 1,
    gradient_checkpointing: bool = False,
    text_encoder_gradient_checkpointing: bool = False,
    checkpointing_steps: int = 500,
    resume_from_checkpoint: Optional[str] = None,
    mixed_precision: Optional[str] = "fp16",
    use_8bit_adam: bool = False,
    enable_xformers_memory_efficient_attention: bool = True,
    enable_torch_2_attn: bool = False,
    seed: Optional[int] = None,
    train_text_encoder: bool = False,
    use_offset_noise: bool = False,
    offset_noise_strength: float = 0.1,
    extend_dataset: bool = False,
    cache_latents: bool = False,
    cached_latent_dir = None,
    use_unet_lora: bool = False,
    use_text_lora: bool = False,
    unet_lora_modules: Tuple[str] = ("ResnetBlock2D",),
    text_encoder_lora_modules: Tuple[str] = ("CLIPEncoderLayer",),
    lora_rank: int = 16,
    lora_path: str = '',
    **kwargs
):
    """Fine-tune the text-to-video UNet (and optionally the text encoder / LoRA layers)."""
    # Must stay the first statement: captures the call arguments as the saved run config.
    *_, config = inspect.getargvalues(inspect.currentframe())

    # Resuming is not implemented (only pipeline weights are saved, no optimizer,
    # scheduler or step state), so fail fast instead of crashing inside the loop.
    if resume_from_checkpoint:
        raise NotImplementedError(
            "resume_from_checkpoint is not supported. To continue from saved weights, "
            "set pretrained_model_path to a saved checkpoint directory."
        )

    accelerator = Accelerator(
        gradient_accumulation_steps=gradient_accumulation_steps,
        mixed_precision=mixed_precision,
        log_with="tensorboard",
        logging_dir=output_dir
    )

    # Make one log on every process with the configuration for debugging.
    create_logging(logging, logger, accelerator)

    # Initialize accelerate, transformers, and diffusers warnings
    accelerate_set_verbose(accelerator)

    # If passed along, set the training seed now.
    if seed is not None:
        set_seed(seed)

    # Handle the output folder creation
    if accelerator.is_main_process:
        output_dir = create_output_folders(output_dir, config)

    # Load scheduler, tokenizer and models.
    noise_scheduler, tokenizer, text_encoder, vae, unet = load_primary_models(pretrained_model_path)

    # Freeze any necessary models
    freeze_models([vae, text_encoder, unet])

    # Enable xformers if available
    handle_memory_attention(enable_xformers_memory_efficient_attention, enable_torch_2_attn, unet)

    if scale_lr:
        learning_rate = (
            learning_rate * gradient_accumulation_steps * train_batch_size * accelerator.num_processes
        )

    # Initialize the optimizer
    optimizer_cls = get_optimizer(use_8bit_adam)

    # Use LoRA if enabled.
    unet_lora_params, unet_negation = inject_lora(
        use_unet_lora, unet, unet_lora_modules, is_extended=True,
        r=lora_rank, lora_path=lora_path
    )

    text_encoder_lora_params, text_encoder_negation = inject_lora(
        use_text_lora, text_encoder, text_encoder_lora_modules,
        r=lora_rank, lora_path=lora_path
    )

    # Create parameters to optimize over with a condition (if "condition" is true, optimize it)
    optim_params = [
        param_optim(unet, trainable_modules is not None, extra_params=extra_unet_params, negation=unet_negation),
        param_optim(text_encoder, train_text_encoder and not use_text_lora, extra_params=extra_text_encoder_params,
                    negation=text_encoder_negation
                    ),
        param_optim(text_encoder_lora_params, use_text_lora, is_lora=True, extra_params={"lr": 1e-5}),
        param_optim(unet_lora_params, use_unet_lora, is_lora=True, extra_params={"lr": 1e-5})
    ]

    params = create_optimizer_params(optim_params, learning_rate)

    # Create Optimizer
    optimizer = optimizer_cls(
        params,
        lr=learning_rate,
        betas=(adam_beta1, adam_beta2),
        weight_decay=adam_weight_decay,
        eps=adam_epsilon,
    )

    # Scheduler
    lr_scheduler = get_scheduler(
        lr_scheduler,
        optimizer=optimizer,
        num_warmup_steps=lr_warmup_steps * gradient_accumulation_steps,
        num_training_steps=max_train_steps * gradient_accumulation_steps,
    )

    # Get the training dataset based on types (json, single_video, image)
    train_datasets = get_train_dataset(dataset_types, train_data, tokenizer)

    # Extend datasets that are less than the greatest one. This allows for more balanced training.
    attrs = ['train_data', 'frames', 'image_dir', 'video_files']
    extend_datasets(train_datasets, attrs, extend=extend_dataset)

    if len(train_datasets) == 1:
        train_dataset = train_datasets[0]
    else:
        train_dataset = torch.utils.data.ConcatDataset(train_datasets)

    # DataLoaders creation:
    train_dataloader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=train_batch_size,
        shuffle=True
    )

    # Latents caching
    cached_data_loader = handle_cache_latents(
        cache_latents,
        output_dir,
        train_dataloader,
        train_batch_size,
        vae,
        cached_latent_dir
    )

    if cached_data_loader is not None:
        train_dataloader = cached_data_loader

    # Prepare everything with our `accelerator`.
    unet, optimizer, train_dataloader, lr_scheduler, text_encoder = accelerator.prepare(
        unet, optimizer, train_dataloader, lr_scheduler, text_encoder
    )

    # Use Gradient Checkpointing if enabled.
    unet_and_text_g_c(
        unet,
        text_encoder,
        gradient_checkpointing,
        text_encoder_gradient_checkpointing
    )

    # Enable VAE slicing to save memory.
    vae.enable_slicing()

    # For mixed precision training we cast the text_encoder and vae weights to half-precision
    # as these models are only used for inference, keeping weights in full precision is not required.
    weight_dtype = is_mixed_precision(accelerator)

    # Move text encoders, and VAE to GPU
    cast_to_gpu_and_type([text_encoder, vae], accelerator, weight_dtype)

    # We need to recalculate our total training steps as the size of the training dataloader may have changed.
    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / gradient_accumulation_steps)

    # Afterwards we recalculate our number of training epochs
    num_train_epochs = math.ceil(max_train_steps / num_update_steps_per_epoch)

    # We need to initialize the trackers we use, and also store our configuration.
    # The trackers initializes automatically on the main process.
    if accelerator.is_main_process:
        accelerator.init_trackers("text2video-fine-tune")

    # Train!
    total_batch_size = train_batch_size * accelerator.num_processes * gradient_accumulation_steps

    logger.info("***** Running training *****")
    logger.info(f"  Num examples = {len(train_dataset)}")
    logger.info(f"  Num Epochs = {num_train_epochs}")
    logger.info(f"  Instantaneous batch size per device = {train_batch_size}")
    logger.info(f"  Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
    logger.info(f"  Gradient Accumulation steps = {gradient_accumulation_steps}")
    logger.info(f"  Total optimization steps = {max_train_steps}")
    global_step = 0
    first_epoch = 0

    # Only show the progress bar once on each machine.
    progress_bar = tqdm(range(global_step, max_train_steps), disable=not accelerator.is_local_main_process)
    progress_bar.set_description("Steps")

    # Whether any text-encoder weights (full or LoRA) are being trained.
    text_trainable = (train_text_encoder or use_text_lora)

    def finetune_unet(batch, train_encoder=False):
        """Compute the diffusion loss for one batch.

        ``train_encoder`` is unused; text training is controlled by
        ``train_text_encoder`` / ``use_text_lora``.

        Returns:
            ``(loss, latents)``.
        """
        # Unfreeze UNET Layers before the first optimizer step.
        if global_step == 0:
            unet.train()
            handle_trainable_modules(
                unet,
                trainable_modules,
                is_enabled=True,
                negation=unet_negation
            )

        # Convert videos to latent space (cached batches already hold latents).
        pixel_values = batch["pixel_values"]

        if not cache_latents:
            latents = tensor_to_vae_latent(pixel_values, vae)
        else:
            latents = pixel_values

        # Get video length (latents are (b, c, f, h, w)).
        video_length = latents.shape[2]

        # Sample noise that we'll add to the latents
        noise = sample_noise(latents, offset_noise_strength, use_offset_noise)
        bsz = latents.shape[0]

        # Sample a random timestep for each video
        timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
        timesteps = timesteps.long()

        # Add noise to the latents according to the noise magnitude at each timestep
        # (this is the forward diffusion process)
        noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

        # Enable text encoder training
        if text_trainable:
            text_encoder.train()

            if use_text_lora:
                text_encoder.text_model.embeddings.requires_grad_(True)

            if global_step == 0 and train_text_encoder:
                handle_trainable_modules(
                    text_encoder,
                    trainable_modules=trainable_text_modules,
                    negation=text_encoder_negation
                )
            cast_to_gpu_and_type([text_encoder], accelerator, torch.float32)

        # Fixes gradient checkpointing training.
        # See: https://github.com/prigoyal/pytorch_memonger/blob/master/tutorial/Checkpointing_for_PyTorch_models.ipynb
        if gradient_checkpointing or text_encoder_gradient_checkpointing:
            unet.eval()
            text_encoder.eval()

        # Encode text embeddings
        token_ids = batch['prompt_ids']
        encoder_hidden_states = text_encoder(token_ids)[0]

        # Get the target for loss depending on the prediction type
        prediction_type = noise_scheduler.config.prediction_type
        if prediction_type == "epsilon":
            target = noise
        elif prediction_type == "v_prediction":
            target = noise_scheduler.get_velocity(latents, noise, timesteps)
        else:
            raise ValueError(f"Unknown prediction type {prediction_type}")

        # Here we do two passes for video and text training.
        # If we are on the second iteration of the loop, get one frame.
        # This allows us to train text information only on the spatial layers.
        losses = []
        should_truncate_video = (video_length > 1 and text_trainable)

        # We detach the encoder hidden states for the first pass (video frames > 1)
        # Then we make a clone of the initial state to ensure we can train it in the loop.
        detached_encoder_state = encoder_hidden_states.clone().detach()
        trainable_encoder_state = encoder_hidden_states.clone()

        for i in range(2):
            should_detach = noisy_latents.shape[2] > 1 and i == 0

            if should_truncate_video and i == 1:
                noisy_latents = noisy_latents[:, :, 1, :, :].unsqueeze(2)
                target = target[:, :, 1, :, :].unsqueeze(2)

            encoder_hidden_states = (
                detached_encoder_state if should_detach else trainable_encoder_state
            )

            model_pred = unet(noisy_latents, timesteps, encoder_hidden_states=encoder_hidden_states).sample
            loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
            losses.append(loss)

            # Single frames / images need only one pass. With a frozen text path the
            # second pass would repeat the first exactly, doubling compute and loss.
            if i == 0 and not should_truncate_video:
                break

        loss = losses[0] if len(losses) == 1 else losses[0] + losses[1]

        return loss, latents

    for epoch in range(first_epoch, num_train_epochs):
        train_loss = 0.0

        for step, batch in enumerate(train_dataloader):
            with accelerator.accumulate(unet), accelerator.accumulate(text_encoder):

                text_prompt = batch['text_prompt'][0]

                with accelerator.autocast():
                    loss, latents = finetune_unet(batch, train_encoder=train_text_encoder)

                # Gather the losses across all processes for logging (if we use distributed training).
                avg_loss = accelerator.gather(loss.repeat(train_batch_size)).mean()
                train_loss += avg_loss.item() / gradient_accumulation_steps

                # Backpropagate
                try:
                    accelerator.backward(loss)
                    # Clip every trained parameter, including text-encoder LoRA weights.
                    params_to_clip = (
                        list(unet.parameters()) + list(text_encoder.parameters())
                        if text_trainable
                        else unet.parameters()
                    )
                    accelerator.clip_grad_norm_(params_to_clip, max_grad_norm)
                    optimizer.step()
                    lr_scheduler.step()
                    optimizer.zero_grad(set_to_none=True)

                except Exception as e:
                    print(f"An error has occured during backpropogation! {e}")
                    # Drop partial gradients so they don't leak into the next step.
                    optimizer.zero_grad(set_to_none=True)
                    continue

            # Checks if the accelerator has performed an optimization step behind the scenes
            if accelerator.sync_gradients:
                progress_bar.update(1)
                global_step += 1
                accelerator.log({"train_loss": train_loss}, step=global_step)
                train_loss = 0.0

                # Runs on every process: save_pipe re-prepares models (collective under DDP).
                if global_step % checkpointing_steps == 0:
                    save_pipe(
                        pretrained_model_path,
                        global_step,
                        accelerator,
                        unet,
                        text_encoder,
                        vae,
                        output_dir,
                        use_unet_lora,
                        use_text_lora,
                        unet_lora_modules,
                        text_encoder_lora_modules,
                        is_checkpoint=True
                    )

                if should_sample(global_step, validation_steps, validation_data):
                    if global_step == 1:
                        print("Performing validation prompt.")

                    if accelerator.is_main_process:
                        with accelerator.autocast():
                            unet.eval()
                            text_encoder.eval()
                            unet_and_text_g_c(unet, text_encoder, False, False)

                            pipeline = TextToVideoSDPipeline.from_pretrained(
                                pretrained_model_path,
                                text_encoder=text_encoder,
                                vae=vae,
                                unet=unet
                            )

                            diffusion_scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config)
                            pipeline.scheduler = diffusion_scheduler

                            prompt = text_prompt if len(validation_data.prompt) <= 0 else validation_data.prompt

                            curr_dataset_name = batch['dataset']
                            # The prompt may contain path separators or be very long.
                            save_filename = _safe_filename(f"{global_step}_dataset-{curr_dataset_name}_{prompt}")

                            out_file = f"{output_dir}/samples/{save_filename}.mp4"

                            with torch.no_grad():
                                video_frames = pipeline(
                                    prompt,
                                    width=validation_data.width,
                                    height=validation_data.height,
                                    num_frames=validation_data.num_frames,
                                    num_inference_steps=validation_data.num_inference_steps,
                                    guidance_scale=validation_data.guidance_scale
                                ).frames
                            export_to_video(video_frames, out_file, train_data.get('fps', 8))

                            del pipeline
                            torch.cuda.empty_cache()

                        # `out_file` only exists on the main process.
                        logger.info(f"Saved a new sample to {out_file}")

                    unet_and_text_g_c(
                        unet,
                        text_encoder,
                        gradient_checkpointing,
                        text_encoder_gradient_checkpointing
                    )

            logs = {"step_loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
            # Log against the global step so values don't overlap across epochs.
            accelerator.log({"training_loss": loss.detach().item()}, step=global_step)
            progress_bar.set_postfix(**logs)

            if global_step >= max_train_steps:
                break

    # Create the pipeline using the trained modules and save it.
    accelerator.wait_for_everyone()
    if accelerator.is_main_process:
        save_pipe(
            pretrained_model_path,
            global_step,
            accelerator,
            unet,
            text_encoder,
            vae,
            output_dir,
            use_unet_lora,
            use_text_lora,
            unet_lora_modules,
            text_encoder_lora_modules,
            is_checkpoint=False
        )
    accelerator.end_training()


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", type=str, default="./configs/my_config.yaml")
    args = parser.parse_args()

    main(**OmegaConf.load(args.config))