#!/usr/bin/env python
# Copyright 2023-present the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# The implementation is based on "Parameter-Efficient Orthogonal Finetuning
# via Butterfly Factorization" (https://huggingface.co/papers/2311.06243) in ICLR 2024.
import itertools
import logging
import math
import os
from pathlib import Path
import datasets
import diffusers
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
import transformers
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import set_seed
from diffusers import (
AutoencoderKL,
DDIMScheduler,
)
from diffusers.optimization import get_scheduler
from diffusers.utils import check_min_version
from diffusers.utils.import_utils import is_xformers_available
from packaging import version
from tqdm.auto import tqdm
from transformers import AutoTokenizer
from utils.args_loader import (
import_model_class_from_model_name_or_path,
parse_args,
)
from utils.dataset import collate_fn, log_validation, make_dataset
from utils.light_controlnet import ControlNetModel
from utils.tracemalloc import TorchTracemalloc, b2mb
from utils.unet_2d_condition import UNet2DConditionNewModel
from peft import BOFTConfig, get_peft_model
from peft.peft_model import PeftModel
# Will raise an error if the minimal version of diffusers is not installed. Remove at your own risk.
check_min_version("0.16.0.dev0")
logger = get_logger(__name__)
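# Attention projection layers that the BOFT adapters are injected into, for the UNet and the text encoder respectively.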
UNET_TARGET_MODULES = ["to_q", "to_v", "to_k", "query", "value", "key"]
TEXT_ENCODER_TARGET_MODULES = ["q_proj", "v_proj"]
@torch.no_grad()
def save_adaptor(accelerator, output_dir, nets_dict):
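"""Save every network in `nets_dict` under `output_dir/<name>`: PEFT-wrapped models are saved
as adapter weights via `save_pretrained`, plain modules via `accelerator.save_model`."""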
for net_key in nets_dict.keys():
net_model = nets_dict[net_key]
unwrapped_net = accelerator.unwrap_model(net_model)
if isinstance(unwrapped_net, PeftModel):
unwrapped_net.save_pretrained(
os.path.join(output_dir, net_key),
state_dict=accelerator.get_state_dict(net_model),
safe_serialization=True,
)
else:
accelerator.save_model(
unwrapped_net,
os.path.join(output_dir, net_key),
safe_serialization=True,
)
def main(args):
logging_dir = Path(args.output_dir, args.logging_dir)
accelerator = Accelerator(
gradient_accumulation_steps=args.gradient_accumulation_steps,
mixed_precision=args.mixed_precision,
log_with=args.report_to,
project_dir=logging_dir,
)
# Default to empty tracker kwargs and only populate them when logging to wandb,
# so the later call to accelerator.init_trackers never references an undefined name.
wandb_init = {}
if args.report_to == "wandb":
wandb_init = {
"wandb": {
"name": args.wandb_run_name,
"mode": "online",
}
}
# Make one log on every process with the configuration for debugging.
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
datefmt="%m/%d/%Y %H:%M:%S",
level=logging.INFO,
)
logger.info(accelerator.state, main_process_only=False)
if accelerator.is_local_main_process:
datasets.utils.logging.set_verbosity_warning()
transformers.utils.logging.set_verbosity_warning()
diffusers.utils.logging.set_verbosity_info()
else:
datasets.utils.logging.set_verbosity_error()
transformers.utils.logging.set_verbosity_error()
diffusers.utils.logging.set_verbosity_error()
# If passed along, set the training seed now.
if args.seed is not None:
set_seed(args.seed)
# Handle the repository creation
if accelerator.is_main_process:
if args.output_dir is not None:
os.makedirs(args.output_dir, exist_ok=True)
# Load the tokenizer
if args.tokenizer_name:
tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name, revision=args.revision, use_fast=False)
elif args.pretrained_model_name_or_path:
tokenizer = AutoTokenizer.from_pretrained(
args.pretrained_model_name_or_path,
subfolder="tokenizer",
revision=args.revision,
use_fast=False,
)
# import correct text encoder class
text_encoder_cls = import_model_class_from_model_name_or_path(args.pretrained_model_name_or_path, args.revision)
# Load scheduler and models
noise_scheduler = DDIMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
text_encoder = text_encoder_cls.from_pretrained(
args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
)
vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision)
unet = UNet2DConditionNewModel.from_pretrained(
args.pretrained_model_name_or_path,
subfolder="unet",
revision=args.revision,
)
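# The lightweight ControlNet produces the conditioning hint that is later fed into the modified UNet.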
controlnet = ControlNetModel()
if args.controlnet_model_name_or_path != "":
logger.info(f"Loading existing controlnet weights from {args.controlnet_model_name_or_path}")
controlnet.load_state_dict(torch.load(args.controlnet_model_name_or_path))
if args.use_boft:
config = BOFTConfig(
boft_block_size=args.boft_block_size,
boft_block_num=args.boft_block_num,
boft_n_butterfly_factor=args.boft_n_butterfly_factor,
target_modules=UNET_TARGET_MODULES,
boft_dropout=args.boft_dropout,
bias=args.boft_bias,
)
unet = get_peft_model(unet, config)
unet.print_trainable_parameters()
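# Only the ControlNet and the BOFT adapter weights are trained; the VAE stays frozen,
# and the text encoder is frozen unless --train_text_encoder is passed.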
vae.requires_grad_(False)
controlnet.requires_grad_(True)
if not args.train_text_encoder:
text_encoder.requires_grad_(False)
unet.train()
controlnet.train()
if args.train_text_encoder and args.use_boft:
config = BOFTConfig(
boft_block_size=args.boft_block_size,
boft_block_num=args.boft_block_num,
boft_n_butterfly_factor=args.boft_n_butterfly_factor,
target_modules=TEXT_ENCODER_TARGET_MODULES,
boft_dropout=args.boft_dropout,
bias=args.boft_bias,
)
text_encoder = get_peft_model(text_encoder, config, adapter_name=args.wandb_run_name)
text_encoder.print_trainable_parameters()
if args.train_text_encoder:
text_encoder.train()
# For mixed precision training we cast the text_encoder and vae weights to half-precision
# as these models are only used for inference, keeping weights in full precision is not required.
weight_dtype = torch.float32
if accelerator.mixed_precision == "fp16":
weight_dtype = torch.float16
elif accelerator.mixed_precision == "bf16":
weight_dtype = torch.bfloat16
# Move unet, vae and text_encoder to device and cast to weight_dtype
unet.to(accelerator.device, dtype=weight_dtype)
vae.to(accelerator.device, dtype=weight_dtype)
controlnet.to(accelerator.device, dtype=weight_dtype)
if not args.train_text_encoder:
text_encoder.to(accelerator.device, dtype=weight_dtype)
if args.enable_xformers_memory_efficient_attention:
if accelerator.device.type == "xpu":
logger.warning("XPU doesn't support xformers yet, xformers is not applied.")
elif is_xformers_available():
import xformers
xformers_version = version.parse(xformers.__version__)
if xformers_version == version.parse("0.0.16"):
logger.warning(
"xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
)
unet.enable_xformers_memory_efficient_attention()
controlnet.enable_xformers_memory_efficient_attention()
if args.train_text_encoder and not (args.use_lora or args.use_boft or args.use_oft):
text_encoder.enable_xformers_memory_efficient_attention()
else:
raise ValueError("xformers is not available. Make sure it is installed correctly")
if args.gradient_checkpointing:
controlnet.enable_gradient_checkpointing()
unet.enable_gradient_checkpointing()
if args.train_text_encoder and not (args.use_lora or args.use_boft or args.use_oft):
text_encoder.gradient_checkpointing_enable()
# Check that all trainable models are in full precision
low_precision_error_string = (
" Please make sure to always have all model weights in full float32 precision when starting training - even if"
" doing mixed precision training, a copy of the weights should still be float32."
)
if accelerator.unwrap_model(controlnet).dtype != torch.float32:
raise ValueError(
f"Controlnet loaded as datatype {accelerator.unwrap_model(controlnet).dtype}. {low_precision_error_string}"
)
if accelerator.unwrap_model(unet).dtype != torch.float32:
raise ValueError(
f"UNet loaded as datatype {accelerator.unwrap_model(unet).dtype}. {low_precision_error_string}"
)
# Enable TF32 for faster training on Ampere GPUs,
# cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
if args.allow_tf32:
torch.backends.cuda.matmul.allow_tf32 = True
if args.scale_lr:
args.learning_rate = (
args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
)
# Use 8-bit Adam for lower memory usage or to fine-tune the model on 16GB GPUs
if args.use_8bit_adam:
try:
import bitsandbytes as bnb
except ImportError:
raise ImportError(
"To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
)
optimizer_class = bnb.optim.AdamW8bit
else:
optimizer_class = torch.optim.AdamW
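# Collect only the parameters that require gradients: the full ControlNet, the BOFT adapter
# weights inside the UNet, and optionally the text encoder adapters.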
params_to_optimize = [param for param in controlnet.parameters() if param.requires_grad]
params_to_optimize += [param for param in unet.parameters() if param.requires_grad]
if args.train_text_encoder:
params_to_optimize += [param for param in text_encoder.parameters() if param.requires_grad]
# Optimizer creation
optimizer = optimizer_class(
params_to_optimize,
lr=args.learning_rate,
betas=(args.adam_beta1, args.adam_beta2),
weight_decay=args.adam_weight_decay,
eps=args.adam_epsilon,
)
# Load the dataset
train_dataset = make_dataset(args, tokenizer, accelerator, "train")
val_dataset = make_dataset(args, tokenizer, accelerator, "test")
train_dataloader = torch.utils.data.DataLoader(
train_dataset,
shuffle=True,
collate_fn=collate_fn,
batch_size=args.train_batch_size,
num_workers=args.dataloader_num_workers,
)
# Scheduler and math around the number of training steps.
overrode_max_train_steps = False
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
if args.max_train_steps is None:
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
overrode_max_train_steps = True
lr_scheduler = get_scheduler(
args.lr_scheduler,
optimizer=optimizer,
num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
num_cycles=args.lr_num_cycles,
power=args.lr_power,
)
# Prepare everything with our `accelerator`.
controlnet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
controlnet, optimizer, train_dataloader, lr_scheduler
)
if args.train_text_encoder:
text_encoder = accelerator.prepare(text_encoder)
# We need to recalculate our total training steps as the size of the training dataloader may have changed.
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
if overrode_max_train_steps:
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
# Afterwards we recalculate our number of training epochs
args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
# We need to initialize the trackers we use, and also store our configuration.
# The trackers initialize automatically on the main process.
if accelerator.is_main_process:
accelerator.init_trackers(args.wandb_project_name, config=vars(args), init_kwargs=wandb_init)
# Train!
total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
logger.info("***** Running training *****")
logger.info(f" Num examples = {len(train_dataset)}")
logger.info(f" Num batches each epoch = {len(train_dataloader)}")
logger.info(f" Num Epochs = {args.num_train_epochs}")
logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
logger.info(f" Total optimization steps = {args.max_train_steps}")
global_step = 0
first_epoch = 0
# Potentially load in the weights and states from a previous save
if args.resume_from_checkpoint:
if args.resume_from_checkpoint != "latest":
path = os.path.basename(args.resume_from_checkpoint)
else:
# Get the most recent checkpoint
dirs = os.listdir(args.output_dir)
if "checkpoint-current" in dirs:
path = "checkpoint-current"
dirs = [d for d in dirs if d.startswith("checkpoint") and d.endswith("0")]
dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
else:
dirs = [d for d in dirs if d.startswith("checkpoint")]
dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
path = dirs[-1] if len(dirs) > 0 else None
if path is None:
accelerator.print(
f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
)
args.resume_from_checkpoint = None
initial_global_step = 0
else:
accelerator.print(f"Resuming from checkpoint {path}")
accelerator.load_state(os.path.join(args.output_dir, path))
if path.split("-")[1] == "current":
global_step = int(dirs[-1].split("-")[1])
else:
global_step = int(path.split("-")[1])
initial_global_step = global_step
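# Convert optimizer steps back into dataloader steps so already-consumed batches can be skipped below.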
resume_global_step = global_step * args.gradient_accumulation_steps
first_epoch = global_step // num_update_steps_per_epoch
resume_step = resume_global_step % (num_update_steps_per_epoch * args.gradient_accumulation_steps)
else:
initial_global_step = 0
progress_bar = tqdm(
range(0, args.max_train_steps),
initial=initial_global_step,
desc="Steps",
disable=not accelerator.is_local_main_process,
)
progress_bar.set_description("Steps")
for epoch in range(first_epoch, args.num_train_epochs):
with TorchTracemalloc() as tracemalloc:
for step, batch in enumerate(train_dataloader):
# Skip steps until we reach the resumed step
if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step:
if step % args.gradient_accumulation_steps == 0:
progress_bar.update(1)
if args.report_to == "wandb":
accelerator.print(progress_bar)
continue
with accelerator.accumulate(controlnet), accelerator.accumulate(unet):
# Convert images to latent space
latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample()
latents = latents * vae.config.scaling_factor
# Sample noise that we'll add to the latents
noise = torch.randn_like(latents)
bsz = latents.shape[0]
# Sample a random timestep for each image
timesteps = torch.randint(
0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device
)
timesteps = timesteps.long()
# Add noise to the latents according to the noise magnitude at each timestep
# (this is the forward diffusion process)
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
# Get the text embedding for conditioning
encoder_hidden_states = text_encoder(batch["input_ids"])[0]
controlnet_image = batch["conditioning_pixel_values"].to(dtype=weight_dtype)
# Get the guided hint for the UNet (320 dim)
guided_hint = controlnet(
controlnet_cond=controlnet_image,
)
# Predict the noise residual
model_pred = unet(
noisy_latents,
timesteps,
guided_hint=guided_hint,
encoder_hidden_states=encoder_hidden_states,
).sample
# Get the target for loss depending on the prediction type
if noise_scheduler.config.prediction_type == "epsilon":
target = noise
elif noise_scheduler.config.prediction_type == "v_prediction":
target = noise_scheduler.get_velocity(latents, noise, timesteps)
else:
raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
accelerator.backward(loss)
if accelerator.sync_gradients:
params_to_clip = (
itertools.chain(controlnet.parameters(), text_encoder.parameters())
if args.train_text_encoder
else itertools.chain(
controlnet.parameters(),
)
)
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
optimizer.step()
lr_scheduler.step()
optimizer.zero_grad(set_to_none=args.set_grads_to_none)
# Checks if the accelerator has performed an optimization step behind the scenes
if accelerator.sync_gradients:
progress_bar.update(1)
if args.report_to == "wandb":
accelerator.print(progress_bar)
global_step += 1
step_save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
if accelerator.is_main_process:
if global_step % args.validation_steps == 0 or global_step == 1:
logger.info(f"Running validation... \n Generating {args.num_validation_images} images.")
logger.info("Running validation... ")
with torch.no_grad():
log_validation(val_dataset, text_encoder, unet, controlnet, args, accelerator)
if global_step % args.checkpointing_steps == 0:
save_adaptor(accelerator, step_save_path, {"controlnet": controlnet, "unet": unet})
# save text_encoder if any
if args.train_text_encoder:
save_adaptor(accelerator, step_save_path, {"text_encoder": text_encoder})
accelerator.save_state(step_save_path)
logger.info(f"Saved {global_step} state to {step_save_path}")
logger.info(f"Saved current state to {step_save_path}")
logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
progress_bar.set_postfix(**logs)
accelerator.log(logs, step=global_step)
if global_step >= args.max_train_steps:
break
# Print GPU and CPU memory usage for this epoch: memory before training, memory consumed, and peak memory
accelerator.print(
f"{accelerator.device.type.upper()} Memory before entering training: {b2mb(tracemalloc.begin)}"
)
accelerator.print(
f"{accelerator.device.type.upper()} Memory consumed at the end of training (end-begin): {tracemalloc.used}"
)
accelerator.print(
f"{accelerator.device.type.upper()} Peak memory consumed during training (max-begin): {tracemalloc.peaked}"
)
accelerator.print(
f"{accelerator.device.type.upper()} Total peak memory consumed during training (max): {tracemalloc.peaked + b2mb(tracemalloc.begin)}"
)
accelerator.print(f"CPU Memory before entering training: {b2mb(tracemalloc.cpu_begin)}")
accelerator.print(f"CPU Memory consumed at the end of training (end-begin): {tracemalloc.cpu_used}")
accelerator.print(f"CPU Peak memory consumed during training (max-begin): {tracemalloc.cpu_peaked}")
accelerator.print(
f"CPU Total peak memory consumed during training (max): {tracemalloc.cpu_peaked + b2mb(tracemalloc.cpu_begin)}"
)
# Create the pipeline using the trained modules and save it.
accelerator.wait_for_everyone()
accelerator.end_training()
if __name__ == "__main__":
args = parse_args()
main(args)