import os
import copy
import inspect
from typing import Optional, List, Dict, Union

import PIL
import click
from omegaconf import OmegaConf

import torch
import torch.utils.data
import torch.nn.functional as F
import torch.utils.checkpoint

from accelerate import Accelerator
from accelerate.utils import set_seed
from diffusers import (
    AutoencoderKL,
    DDPMScheduler,
    DDIMScheduler,
    UNet2DConditionModel,
)
from diffusers.optimization import get_scheduler
from diffusers.utils.import_utils import is_xformers_available
from diffusers.pipeline_utils import DiffusionPipeline
from tqdm.auto import tqdm
from transformers import AutoTokenizer, CLIPTextModel
from einops import rearrange

from video_diffusion.models.unet_3d_condition import UNetPseudo3DConditionModel
from video_diffusion.data.dataset import ImageSequenceDataset
from video_diffusion.common.util import get_time_string, get_function_args
from video_diffusion.common.logger import get_logger_config_path
from video_diffusion.common.image_util import log_train_samples, log_train_reg_samples
from video_diffusion.common.instantiate_from_config import instantiate_from_config, get_obj_from_str
from video_diffusion.pipelines.validation_loop import SampleLogger


def collate_fn(examples):
    batch = {
        "prompt_ids": torch.cat([example["prompt_ids"] for example in examples], dim=0),
        "images": torch.stack([example["images"] for example in examples]),
    }
    if "class_images" in examples[0]:
        batch["class_prompt_ids"] = torch.cat([example["class_prompt_ids"] for example in examples], dim=0)
        batch["class_images"] = torch.stack([example["class_images"] for example in examples])
    return batch


def train(
    config: str,
    pretrained_model_path: str,
    train_dataset: Dict,
    logdir: str = None,
    train_steps: int = 300,
    validation_steps: int = 1000,
    validation_sample_logger_config: Optional[Dict] = None,
    test_pipeline_config: Optional[Dict] = dict(),
    trainer_pipeline_config: Optional[Dict] = dict(),
    gradient_accumulation_steps: int = 1,
    seed: Optional[int] = None,
    mixed_precision: Optional[str] = "fp16",
    enable_xformers: bool = True,
    train_batch_size: int = 1,
    learning_rate: float = 3e-5,
    scale_lr: bool = False,
    lr_scheduler: str = "constant",  # ["linear", "cosine", "cosine_with_restarts", "polynomial", "constant", "constant_with_warmup"]
    lr_warmup_steps: int = 0,
    use_8bit_adam: bool = True,
    adam_beta1: float = 0.9,
    adam_beta2: float = 0.999,
    adam_weight_decay: float = 1e-2,
    adam_epsilon: float = 1e-08,
    max_grad_norm: float = 1.0,
    gradient_checkpointing: bool = False,
    train_temporal_conv: bool = False,
    checkpointing_steps: int = 1000,
    model_config: dict = {},
    # use_train_latents: bool=False,
    # kwr
    # **kwargs
):
    args = get_function_args()
    # args.update(kwargs)
    train_dataset_config = copy.deepcopy(train_dataset)
    time_string = get_time_string()
    if logdir is None:
        logdir = config.replace('config', 'result').replace('.yml', '').replace('.yaml', '')
    logdir += f"_{time_string}"

    accelerator = Accelerator(
        gradient_accumulation_steps=gradient_accumulation_steps,
        mixed_precision=mixed_precision,
    )
    if accelerator.is_main_process:
        os.makedirs(logdir, exist_ok=True)
        OmegaConf.save(args, os.path.join(logdir, "config.yml"))
    logger = get_logger_config_path(logdir)
    if seed is not None:
        set_seed(seed)

    # Load the tokenizer
    tokenizer = AutoTokenizer.from_pretrained(
        pretrained_model_path,
        subfolder="tokenizer",
        use_fast=False,
    )

    # Load models and create wrapper for stable diffusion
    text_encoder = CLIPTextModel.from_pretrained(
        pretrained_model_path,
        subfolder="text_encoder",
    )
    vae = AutoencoderKL.from_pretrained(
        pretrained_model_path,
        subfolder="vae",
    )
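    # Inflate the pretrained 2D Stable Diffusion UNet into a pseudo-3D UNet
    # (UNetPseudo3DConditionModel.from_2d_model reuses the 2D spatial weights and adds temporal layers).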
    unet = UNetPseudo3DConditionModel.from_2d_model(
        os.path.join(pretrained_model_path, "unet"), model_config=model_config
    )

    if 'target' not in test_pipeline_config:
        test_pipeline_config['target'] = 'video_diffusion.pipelines.stable_diffusion.SpatioTemporalStableDiffusionPipeline'

    pipeline = instantiate_from_config(
        test_pipeline_config,
        vae=vae,
        text_encoder=text_encoder,
        tokenizer=tokenizer,
        unet=unet,
        scheduler=DDIMScheduler.from_pretrained(
            pretrained_model_path,
            subfolder="scheduler",
        ),
    )
    pipeline.scheduler.set_timesteps(validation_sample_logger_config['num_inference_steps'])
    pipeline.set_progress_bar_config(disable=True)

    if is_xformers_available() and enable_xformers:
        # if False:  # Disable xformers for null inversion
        try:
            pipeline.enable_xformers_memory_efficient_attention()
            print('Enabled xformers memory-efficient attention for training and testing')
        except Exception as e:
            logger.warning(
                "Could not enable memory efficient attention. Make sure xformers is installed"
                f" correctly and a GPU is available: {e}"
            )

    vae.requires_grad_(False)
    unet.requires_grad_(False)
    text_encoder.requires_grad_(False)

    # Start of config trainable parameters in Unet and optimizer
    # Only modules whose names end with these suffixes are unfrozen: the temporal attention
    # blocks and the query projections of the attention layers (plus, optionally, temporal convs).
    trainable_modules = ("attn_temporal", ".to_q")
    if train_temporal_conv:
        trainable_modules += ("conv_temporal",)
    for name, module in unet.named_modules():
        if name.endswith(trainable_modules):
            for params in module.parameters():
                params.requires_grad = True

    if gradient_checkpointing:
        print('Enabled gradient checkpointing for the UNet')
        unet.enable_gradient_checkpointing()

    if scale_lr:
        learning_rate = (
            learning_rate * gradient_accumulation_steps * train_batch_size * accelerator.num_processes
        )

    # Use 8-bit Adam for lower memory usage or to fine-tune the model on 16GB GPUs
    if use_8bit_adam:
        try:
            import bitsandbytes as bnb
        except ImportError:
            raise ImportError(
                "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
            )
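        # bitsandbytes' AdamW8bit keeps optimizer state in 8-bit, which substantially reduces
        # optimizer memory compared to full-precision AdamW (general bitsandbytes behaviour).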
        optimizer_class = bnb.optim.AdamW8bit
    else:
        optimizer_class = torch.optim.AdamW

    params_to_optimize = unet.parameters()
    num_trainable_modules = 0
    num_trainable_params = 0
    num_unet_params = 0
    for params in params_to_optimize:
        num_unet_params += params.numel()
        if params.requires_grad:
            num_trainable_modules += 1
            num_trainable_params += params.numel()

    logger.info(f"Num of trainable modules: {num_trainable_modules}")
    logger.info(f"Num of trainable params: {num_trainable_params/(1024*1024):.2f} M")
    logger.info(f"Num of unet params: {num_unet_params/(1024*1024):.2f} M")

    # unet.parameters() is a generator and was exhausted by the counting loop above,
    # so request a fresh one for the optimizer.
    params_to_optimize = unet.parameters()
    optimizer = optimizer_class(
        params_to_optimize,
        lr=learning_rate,
        betas=(adam_beta1, adam_beta2),
        weight_decay=adam_weight_decay,
        eps=adam_epsilon,
    )
    # End of config trainable parameters in Unet and optimizer

    prompt_ids = tokenizer(
        train_dataset["prompt"],
        truncation=True,
        padding="max_length",
        max_length=tokenizer.model_max_length,
        return_tensors="pt",
    ).input_ids

    if 'class_data_root' in train_dataset_config:
        if 'class_data_prompt' not in train_dataset_config:
            train_dataset_config['class_data_prompt'] = train_dataset_config['prompt']
        class_prompt_ids = tokenizer(
            train_dataset_config["class_data_prompt"],
            truncation=True,
            padding="max_length",
            max_length=tokenizer.model_max_length,
            return_tensors="pt",
        ).input_ids
    else:
        class_prompt_ids = None

    train_dataset = ImageSequenceDataset(**train_dataset, prompt_ids=prompt_ids, class_prompt_ids=class_prompt_ids)

    train_dataloader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=train_batch_size,
        shuffle=True,
        num_workers=0,
        collate_fn=collate_fn,
    )

    train_sample_save_path = os.path.join(logdir, "train_samples.gif")
    log_train_samples(save_path=train_sample_save_path, train_dataloader=train_dataloader)
    if 'class_data_root' in train_dataset_config:
        log_train_reg_samples(
            save_path=train_sample_save_path.replace('train_samples', 'class_data_samples'),
            train_dataloader=train_dataloader,
        )

    # Prepare learning rate scheduler in accelerate config
    lr_scheduler = get_scheduler(
        lr_scheduler,
        optimizer=optimizer,
        num_warmup_steps=lr_warmup_steps * gradient_accumulation_steps,
        num_training_steps=train_steps * gradient_accumulation_steps,
    )

    unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
        unet, optimizer, train_dataloader, lr_scheduler
    )
    accelerator.register_for_checkpointing(lr_scheduler)

    weight_dtype = torch.float32
    if accelerator.mixed_precision == "fp16":
        weight_dtype = torch.float16
        print('Training with fp16 mixed precision')
    elif accelerator.mixed_precision == "bf16":
        weight_dtype = torch.bfloat16

    # Move text_encoder and vae to GPU.
    # For mixed precision training we cast the text_encoder and vae weights to half-precision,
    # as these models are only used for inference; keeping weights in full precision is not required.
    vae.to(accelerator.device, dtype=weight_dtype)
    text_encoder.to(accelerator.device, dtype=weight_dtype)

    # We need to initialize the trackers we use, and also store our configuration.
    # The trackers are initialized automatically on the main process.
    if accelerator.is_main_process:
        accelerator.init_trackers("video")  # , config=vars(args))
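    # Training uses a DDPM scheduler inside the trainer, while the validation pipeline above
    # samples with DDIM; both are loaded from the same pretrained scheduler config.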
    # Start of config trainer
    trainer = instantiate_from_config(
        trainer_pipeline_config,
        vae=vae,
        text_encoder=text_encoder,
        tokenizer=tokenizer,
        unet=unet,
        scheduler=DDPMScheduler.from_pretrained(
            pretrained_model_path,
            subfolder="scheduler",
        ),
        # training hyperparams
        weight_dtype=weight_dtype,
        accelerator=accelerator,
        optimizer=optimizer,
        max_grad_norm=max_grad_norm,
        lr_scheduler=lr_scheduler,
        prior_preservation=None,
    )
    trainer.print_pipeline(logger)

    # Train!
    total_batch_size = train_batch_size * accelerator.num_processes * gradient_accumulation_steps
    logger.info("***** Running training *****")
    logger.info(f"  Num examples = {len(train_dataset)}")
    logger.info(f"  Num batches each epoch = {len(train_dataloader)}")
    logger.info(f"  Instantaneous batch size per device = {train_batch_size}")
    logger.info(
        f"  Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}"
    )
    logger.info(f"  Gradient Accumulation steps = {gradient_accumulation_steps}")
    logger.info(f"  Total optimization steps = {train_steps}")
    step = 0
    # End of config trainer

    validation_sample_logger = None
    if validation_sample_logger_config is not None and accelerator.is_main_process:
        validation_sample_logger = SampleLogger(**validation_sample_logger_config, logdir=logdir)

    # Only show the progress bar once on each machine.
    progress_bar = tqdm(
        range(step, train_steps),
        disable=not accelerator.is_local_main_process,
    )
    progress_bar.set_description("Steps")

    def make_data_yielder(dataloader):
        while True:
            for batch in dataloader:
                yield batch
            accelerator.wait_for_everyone()

    accelerator.wait_for_everyone()
    train_data_yielder = make_data_yielder(train_dataloader)

    assert train_dataset.overfit_length == 1, "Only support overfitting on a single video"
    # batch = next(train_data_yielder)

    while step < train_steps:
        batch = next(train_data_yielder)

        """************************* start of an iteration*******************************"""
        loss = trainer.step(batch)
        # torch.cuda.empty_cache()
        """************************* end of an iteration*******************************"""

        # Checks if the accelerator has performed an optimization step behind the scenes
        if accelerator.sync_gradients:
            progress_bar.update(1)
            step += 1

            if accelerator.is_main_process:
                if validation_sample_logger is not None and (step % validation_steps == 0):
                    unet.eval()
                    val_image = rearrange(batch["images"].to(dtype=weight_dtype), "b c f h w -> (b f) c h w")

                    # The UNet changes between iterations, so the DDIM inversion is redone online.
                    if validation_sample_logger_config.get('use_train_latents', False):
                        # Precompute the latents for this video to align the initial latents in training and test
                        assert batch["images"].shape[0] == 1, "Only support overfitting on a single video"
                        # We only run inference for latents here, no training
                        vae.eval()
                        text_encoder.eval()
                        unet.eval()
                        text_embeddings = pipeline._encode_prompt(
                            train_dataset.prompt,
                            device=accelerator.device,
                            num_images_per_prompt=1,
                            do_classifier_free_guidance=True,
                            negative_prompt=None,
                        )
                        batch['latents_all_step'] = pipeline.prepare_latents_ddim_inverted(
                            rearrange(batch["images"].to(dtype=weight_dtype), "b c f h w -> (b f) c h w"),
                            batch_size=1,
                            num_images_per_prompt=1,  # not sure how to use it
                            text_embeddings=text_embeddings,
                        )
                        batch['ddim_init_latents'] = batch['latents_all_step'][-1]
                    else:
                        batch['ddim_init_latents'] = None
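                    # Log validation samples; when use_train_latents is set, sampling starts from the
                    # DDIM-inverted latents computed above so that training and test share initial noise.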
512, 512]) image= val_image, # torch.Size([8, 3, 512, 512]) pipeline=pipeline, device=accelerator.device, step=step, latents = batch['ddim_init_latents'], ) torch.cuda.empty_cache() unet.train() if step % checkpointing_steps == 0: accepts_keep_fp32_wrapper = "keep_fp32_wrapper" in set( inspect.signature(accelerator.unwrap_model).parameters.keys() ) extra_args = {"keep_fp32_wrapper": True} if accepts_keep_fp32_wrapper else {} pipeline_save = get_obj_from_str(test_pipeline_config["target"]).from_pretrained( pretrained_model_path, unet=accelerator.unwrap_model(unet, **extra_args), ) checkpoint_save_path = os.path.join(logdir, f"checkpoint_{step}") pipeline_save.save_pretrained(checkpoint_save_path) logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]} progress_bar.set_postfix(**logs) accelerator.log(logs, step=step) accelerator.end_training() @click.command() @click.option("--config", type=str, default="config/sample.yml") def run(config): train(config=config, **OmegaConf.load(config)) if __name__ == "__main__": run()