import os
import copy
import inspect
from typing import Optional, List, Dict, Union

import PIL
import click
from omegaconf import OmegaConf

import torch
import torch.utils.data
import torch.nn.functional as F
import torch.utils.checkpoint

from accelerate import Accelerator
from accelerate.utils import set_seed
from diffusers import (
    AutoencoderKL,
    DDPMScheduler,
    DDIMScheduler,
    UNet2DConditionModel,
)
from diffusers.optimization import get_scheduler
from diffusers.utils.import_utils import is_xformers_available
from diffusers.pipeline_utils import DiffusionPipeline
from tqdm.auto import tqdm
from transformers import AutoTokenizer, CLIPTextModel
from einops import rearrange

from video_diffusion.models.unet_3d_condition import UNetPseudo3DConditionModel
from video_diffusion.data.dataset import ImageSequenceDataset
from video_diffusion.common.util import get_time_string, get_function_args
from video_diffusion.common.logger import get_logger_config_path
from video_diffusion.common.image_util import log_train_samples, log_train_reg_samples
from video_diffusion.common.instantiate_from_config import instantiate_from_config, get_obj_from_str
from video_diffusion.pipelines.validation_loop import SampleLogger


def collate_fn(examples):
    """Collate per-example tensors into a batch; class (regularization) images are optional."""
    batch = {
        "prompt_ids": torch.cat([example["prompt_ids"] for example in examples], dim=0),
        "images": torch.stack([example["images"] for example in examples]),
    }
    if "class_images" in examples[0]:
        batch["class_prompt_ids"] = torch.cat([example["class_prompt_ids"] for example in examples], dim=0)
        batch["class_images"] = torch.stack([example["class_images"] for example in examples])
    return batch


def train(
    config: str,
    pretrained_model_path: str,
    train_dataset: Dict,
    logdir: str = None,
    train_steps: int = 300,
    validation_steps: int = 1000,
    validation_sample_logger_config: Optional[Dict] = None,
    test_pipeline_config: Optional[Dict] = dict(),
    trainer_pipeline_config: Optional[Dict] = dict(),
    gradient_accumulation_steps: int = 1,
    seed: Optional[int] = None,
    mixed_precision: Optional[str] = "fp16",
    enable_xformers: bool = True,
    train_batch_size: int = 1,
    learning_rate: float = 3e-5,
    scale_lr: bool = False,
    lr_scheduler: str = "constant",  # ["linear", "cosine", "cosine_with_restarts", "polynomial", "constant", "constant_with_warmup"]
    lr_warmup_steps: int = 0,
    use_8bit_adam: bool = True,
    adam_beta1: float = 0.9,
    adam_beta2: float = 0.999,
    adam_weight_decay: float = 1e-2,
    adam_epsilon: float = 1e-08,
    max_grad_norm: float = 1.0,
    gradient_checkpointing: bool = False,
    train_temporal_conv: bool = False,
    checkpointing_steps: int = 1000,
    model_config: dict = {},
    # use_train_latents: bool = False,
    # **kwargs
):
    args = get_function_args()
    # args.update(kwargs)
    train_dataset_config = copy.deepcopy(train_dataset)
    time_string = get_time_string()
    if logdir is None:
        logdir = config.replace('config', 'result').replace('.yml', '').replace('.yaml', '')
    logdir += f"_{time_string}"

    accelerator = Accelerator(
        gradient_accumulation_steps=gradient_accumulation_steps,
        mixed_precision=mixed_precision,
    )
    if accelerator.is_main_process:
        os.makedirs(logdir, exist_ok=True)
        OmegaConf.save(args, os.path.join(logdir, "config.yml"))
    logger = get_logger_config_path(logdir)
    if seed is not None:
        set_seed(seed)
    # Load the tokenizer
    tokenizer = AutoTokenizer.from_pretrained(
        pretrained_model_path,
        subfolder="tokenizer",
        use_fast=False,
    )
    # Load models and create wrapper for stable diffusion
    text_encoder = CLIPTextModel.from_pretrained(
        pretrained_model_path,
        subfolder="text_encoder",
    )
    vae = AutoencoderKL.from_pretrained(
        pretrained_model_path,
        subfolder="vae",
    )
    unet = UNetPseudo3DConditionModel.from_2d_model(
        os.path.join(pretrained_model_path, "unet"), model_config=model_config
    )

    if 'target' not in test_pipeline_config:
        test_pipeline_config['target'] = 'video_diffusion.pipelines.stable_diffusion.SpatioTemporalStableDiffusionPipeline'
    pipeline = instantiate_from_config(
        test_pipeline_config,
        vae=vae,
        text_encoder=text_encoder,
        tokenizer=tokenizer,
        unet=unet,
        scheduler=DDIMScheduler.from_pretrained(
            pretrained_model_path,
            subfolder="scheduler",
        ),
    )
    if validation_sample_logger_config is not None:
        pipeline.scheduler.set_timesteps(validation_sample_logger_config['num_inference_steps'])
    pipeline.set_progress_bar_config(disable=True)

    if is_xformers_available() and enable_xformers:
        # if False: # Disable xformers for null inversion
        try:
            pipeline.enable_xformers_memory_efficient_attention()
            print('enabled xformers for training and testing')
        except Exception as e:
            logger.warning(
                "Could not enable memory efficient attention. Make sure xformers is installed"
                f" correctly and a GPU is available: {e}"
            )

    vae.requires_grad_(False)
    unet.requires_grad_(False)
    text_encoder.requires_grad_(False)

    # Start of config trainable parameters in Unet and optimizer
    trainable_modules = ("attn_temporal", ".to_q")
    if train_temporal_conv:
        trainable_modules += ("conv_temporal",)
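    # Unfreeze only the modules whose names end with one of the suffixes above:
    # the temporal attention blocks, the attention query projections, and
    # (optionally) the temporal convolutions; the rest of the UNet stays frozen.
    # str.endswith accepts a tuple of suffixes, so one check covers all of them.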
    for name, module in unet.named_modules():
        if name.endswith(trainable_modules):
            for params in module.parameters():
                params.requires_grad = True

    if gradient_checkpointing:
        print('enabled gradient checkpointing for training')
        unet.enable_gradient_checkpointing()

    if scale_lr:
        learning_rate = (
            learning_rate * gradient_accumulation_steps * train_batch_size * accelerator.num_processes
        )
    # Use 8-bit Adam for lower memory usage or to fine-tune the model on 16GB GPUs
    if use_8bit_adam:
        try:
            import bitsandbytes as bnb
        except ImportError:
            raise ImportError(
                "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
            )
        optimizer_class = bnb.optim.AdamW8bit
    else:
        optimizer_class = torch.optim.AdamW
    params_to_optimize = unet.parameters()
    num_trainable_tensors = 0
    num_trainable_params = 0
    num_unet_params = 0
    for params in params_to_optimize:
        num_unet_params += params.numel()
        if params.requires_grad:
            num_trainable_tensors += 1
            num_trainable_params += params.numel()

    logger.info(f"Num of trainable parameter tensors: {num_trainable_tensors}")
    logger.info(f"Num of trainable params: {num_trainable_params / 1e6:.2f} M")
    logger.info(f"Num of unet params: {num_unet_params / 1e6:.2f} M")
    params_to_optimize = unet.parameters()
    optimizer = optimizer_class(
        params_to_optimize,
        lr=learning_rate,
        betas=(adam_beta1, adam_beta2),
        weight_decay=adam_weight_decay,
        eps=adam_epsilon,
    )
    # End of config trainable parameters in Unet and optimizer

    prompt_ids = tokenizer(
        train_dataset["prompt"],
        truncation=True,
        padding="max_length",
        max_length=tokenizer.model_max_length,
        return_tensors="pt",
    ).input_ids

    if 'class_data_root' in train_dataset_config:
        if 'class_data_prompt' not in train_dataset_config:
            train_dataset_config['class_data_prompt'] = train_dataset_config['prompt']
        class_prompt_ids = tokenizer(
            train_dataset_config["class_data_prompt"],
            truncation=True,
            padding="max_length",
            max_length=tokenizer.model_max_length,
            return_tensors="pt",
        ).input_ids
    else:
        class_prompt_ids = None
    train_dataset = ImageSequenceDataset(**train_dataset, prompt_ids=prompt_ids, class_prompt_ids=class_prompt_ids)

    train_dataloader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=train_batch_size,
        shuffle=True,
        num_workers=0,
        collate_fn=collate_fn,
    )

    train_sample_save_path = os.path.join(logdir, "train_samples.gif")
    log_train_samples(save_path=train_sample_save_path, train_dataloader=train_dataloader)
    if 'class_data_root' in train_dataset_config:
        log_train_reg_samples(save_path=train_sample_save_path.replace('train_samples', 'class_data_samples'), train_dataloader=train_dataloader)
    # Prepare learning rate scheduler in accelerate config
    lr_scheduler = get_scheduler(
        lr_scheduler,
        optimizer=optimizer,
        num_warmup_steps=lr_warmup_steps * gradient_accumulation_steps,
        num_training_steps=train_steps * gradient_accumulation_steps,
    )

    unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
        unet, optimizer, train_dataloader, lr_scheduler
    )
    accelerator.register_for_checkpointing(lr_scheduler)

    weight_dtype = torch.float32
    if accelerator.mixed_precision == "fp16":
        weight_dtype = torch.float16
        print('enabled float16 for training and testing')
    elif accelerator.mixed_precision == "bf16":
        weight_dtype = torch.bfloat16

    # Move text_encoder and vae to gpu.
    # For mixed precision training we cast the text_encoder and vae weights to half-precision,
    # as these models are only used for inference; keeping their weights in full precision is not required.
    vae.to(accelerator.device, dtype=weight_dtype)
    text_encoder.to(accelerator.device, dtype=weight_dtype)

    # We need to initialize the trackers we use, and also store our configuration.
    # The trackers initialize automatically on the main process.
    if accelerator.is_main_process:
        accelerator.init_trackers("video")  # , config=vars(args))

    # Start of config trainer
    trainer = instantiate_from_config(
        trainer_pipeline_config,
        vae=vae,
        text_encoder=text_encoder,
        tokenizer=tokenizer,
        unet=unet,
        scheduler=DDPMScheduler.from_pretrained(
            pretrained_model_path,
            subfolder="scheduler",
        ),
        # training hyperparams
        weight_dtype=weight_dtype,
        accelerator=accelerator,
        optimizer=optimizer,
        max_grad_norm=max_grad_norm,
        lr_scheduler=lr_scheduler,
        prior_preservation=None,
    )
    trainer.print_pipeline(logger)
    # Train!
    total_batch_size = train_batch_size * accelerator.num_processes * gradient_accumulation_steps
    logger.info("***** Running training *****")
    logger.info(f" Num examples = {len(train_dataset)}")
    logger.info(f" Num batches each epoch = {len(train_dataloader)}")
    logger.info(f" Instantaneous batch size per device = {train_batch_size}")
    logger.info(
        f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}"
    )
    logger.info(f" Gradient Accumulation steps = {gradient_accumulation_steps}")
    logger.info(f" Total optimization steps = {train_steps}")
    step = 0
    # End of config trainer

    validation_sample_logger = None
    if validation_sample_logger_config is not None and accelerator.is_main_process:
        validation_sample_logger = SampleLogger(**validation_sample_logger_config, logdir=logdir)

    # Only show the progress bar once on each machine.
    progress_bar = tqdm(
        range(step, train_steps),
        disable=not accelerator.is_local_main_process,
    )
    progress_bar.set_description("Steps")
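    # The dataloader covers only a single video, so instead of epoch-based
    # iteration we cycle over it indefinitely and count optimizer steps.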
    def make_data_yielder(dataloader):
        while True:
            for batch in dataloader:
                yield batch
            accelerator.wait_for_everyone()

    train_data_yielder = make_data_yielder(train_dataloader)

    assert train_dataset.overfit_length == 1, "Only support overfitting on a single video"
    while step < train_steps:
        batch = next(train_data_yielder)
        """************************* start of an iteration *******************************"""
        loss = trainer.step(batch)
        # torch.cuda.empty_cache()
        """************************* end of an iteration *******************************"""

        # Checks if the accelerator has performed an optimization step behind the scenes
        if accelerator.sync_gradients:
            progress_bar.update(1)
            step += 1

            if accelerator.is_main_process:
                if validation_sample_logger is not None and (step % validation_steps == 0):
                    unet.eval()
                    val_image = rearrange(batch["images"].to(dtype=weight_dtype), "b c f h w -> (b f) c h w")

                    # The UNet changes between iterations, so we should invert online
                    if validation_sample_logger_config.get('use_train_latents', False):
                        # Precompute the latents for this video to align the initial latents in training and test
                        assert batch["images"].shape[0] == 1, "Only support overfitting on a single video"
                        # We only run inference for the latents here; no training
                        vae.eval()
                        text_encoder.eval()
                        unet.eval()
                        text_embeddings = pipeline._encode_prompt(
                            train_dataset.prompt,
                            device=accelerator.device,
                            num_images_per_prompt=1,
                            do_classifier_free_guidance=True,
                            negative_prompt=None,
                        )
                        batch['latents_all_step'] = pipeline.prepare_latents_ddim_inverted(
                            rearrange(batch["images"].to(dtype=weight_dtype), "b c f h w -> (b f) c h w"),
                            batch_size=1,
                            num_images_per_prompt=1,  # not sure how to use it
                            text_embeddings=text_embeddings,
                        )
                        batch['ddim_init_latents'] = batch['latents_all_step'][-1]
                    else:
                        batch['ddim_init_latents'] = None

                    validation_sample_logger.log_sample_images(
                        # image=rearrange(train_dataset.get_all()["images"].to(accelerator.device, dtype=weight_dtype), "c f h w -> f c h w"),  # torch.Size([8, 3, 512, 512])
                        image=val_image,  # torch.Size([8, 3, 512, 512])
                        pipeline=pipeline,
                        device=accelerator.device,
                        step=step,
                        latents=batch['ddim_init_latents'],
                    )
                    torch.cuda.empty_cache()
                    unet.train()

                if step % checkpointing_steps == 0:
                    accepts_keep_fp32_wrapper = "keep_fp32_wrapper" in set(
                        inspect.signature(accelerator.unwrap_model).parameters.keys()
                    )
                    extra_args = {"keep_fp32_wrapper": True} if accepts_keep_fp32_wrapper else {}
                    pipeline_save = get_obj_from_str(test_pipeline_config["target"]).from_pretrained(
                        pretrained_model_path,
                        unet=accelerator.unwrap_model(unet, **extra_args),
                    )
                    checkpoint_save_path = os.path.join(logdir, f"checkpoint_{step}")
                    pipeline_save.save_pretrained(checkpoint_save_path)

        logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
        progress_bar.set_postfix(**logs)
        accelerator.log(logs, step=step)

    accelerator.end_training()


@click.command()
@click.option("--config", type=str, required=True, help="Path to the YAML training config")
def run(config):
    train(config=config, **OmegaConf.load(config))


if __name__ == "__main__":
    run()
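# Example invocation (the script name and config path below are placeholders,
# not files shipped with this snippet):
#   accelerate launch train_tune_a_video.py --config path/to/config.yaml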