# ************************************************************************* # This file may have been modified by Bytedance Inc. (“Bytedance Inc.'s Mo- # difications”). All Bytedance Inc.'s Modifications are Copyright (2023) B- # ytedance Inc.. # ************************************************************************* import argparse import gc import hashlib import itertools import logging import math import os import warnings from pathlib import Path import numpy as np import torch import torch.nn.functional as F import torch.utils.checkpoint import transformers from accelerate import Accelerator from accelerate.logging import get_logger from accelerate.utils import ProjectConfiguration, set_seed from huggingface_hub import create_repo, upload_folder from packaging import version from PIL import Image from PIL.ImageOps import exif_transpose from torch.utils.data import Dataset from torchvision import transforms from tqdm.auto import tqdm from transformers import AutoTokenizer, PretrainedConfig import diffusers from diffusers import ( AutoencoderKL, DDPMScheduler, DiffusionPipeline, DPMSolverMultistepScheduler, StableDiffusionPipeline, UNet2DConditionModel, ) from diffusers.loaders import AttnProcsLayers, LoraLoaderMixin from diffusers.models.attention_processor import ( AttnAddedKVProcessor, AttnAddedKVProcessor2_0, LoRAAttnAddedKVProcessor, LoRAAttnProcessor, LoRAAttnProcessor2_0, SlicedAttnAddedKVProcessor, ) from diffusers.optimization import get_scheduler from diffusers.utils import TEXT_ENCODER_ATTN_MODULE, check_min_version, is_wandb_available from diffusers.utils.import_utils import is_xformers_available # Will error if the minimal version of diffusers is not installed. Remove at your own risks. check_min_version("0.17.0") logger = get_logger(__name__) def save_model_card( repo_id: str, images=None, base_model=str, train_text_encoder=False, prompt=str, repo_folder=None, pipeline: DiffusionPipeline = None, ): img_str = "" for i, image in enumerate(images): image.save(os.path.join(repo_folder, f"image_{i}.png")) img_str += f"![img_{i}](./image_{i}.png)\n" yaml = f""" --- license: creativeml-openrail-m base_model: {base_model} instance_prompt: {prompt} tags: - {'stable-diffusion' if isinstance(pipeline, StableDiffusionPipeline) else 'if'} - {'stable-diffusion-diffusers' if isinstance(pipeline, StableDiffusionPipeline) else 'if-diffusers'} - text-to-image - diffusers - lora inference: true --- """ model_card = f""" # LoRA DreamBooth - {repo_id} These are LoRA adaption weights for {base_model}. The weights were trained on {prompt} using [DreamBooth](https://dreambooth.github.io/). You can find some example images in the following. \n {img_str} LoRA for the text encoder was enabled: {train_text_encoder}. """ with open(os.path.join(repo_folder, "README.md"), "w") as f: f.write(yaml + model_card) def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: str, revision: str): text_encoder_config = PretrainedConfig.from_pretrained( pretrained_model_name_or_path, subfolder="text_encoder", revision=revision, ) model_class = text_encoder_config.architectures[0] if model_class == "CLIPTextModel": from transformers import CLIPTextModel return CLIPTextModel elif model_class == "RobertaSeriesModelWithTransformation": from diffusers.pipelines.alt_diffusion.modeling_roberta_series import RobertaSeriesModelWithTransformation return RobertaSeriesModelWithTransformation elif model_class == "T5EncoderModel": from transformers import T5EncoderModel return T5EncoderModel else: raise ValueError(f"{model_class} is not supported.") def parse_args(input_args=None): parser = argparse.ArgumentParser(description="Simple example of a training script.") parser.add_argument( "--pretrained_model_name_or_path", type=str, default=None, required=True, help="Path to pretrained model or model identifier from huggingface.co/models.", ) parser.add_argument( "--revision", type=str, default=None, required=False, help="Revision of pretrained model identifier from huggingface.co/models.", ) parser.add_argument( "--tokenizer_name", type=str, default=None, help="Pretrained tokenizer name or path if not the same as model_name", ) parser.add_argument( "--instance_data_dir", type=str, default=None, required=True, help="A folder containing the training data of instance images.", ) parser.add_argument( "--class_data_dir", type=str, default=None, required=False, help="A folder containing the training data of class images.", ) parser.add_argument( "--instance_prompt", type=str, default=None, required=True, help="The prompt with identifier specifying the instance", ) parser.add_argument( "--class_prompt", type=str, default=None, help="The prompt to specify images in the same class as provided instance images.", ) parser.add_argument( "--validation_prompt", type=str, default=None, help="A prompt that is used during validation to verify that the model is learning.", ) parser.add_argument( "--num_validation_images", type=int, default=4, help="Number of images that should be generated during validation with `validation_prompt`.", ) parser.add_argument( "--validation_epochs", type=int, default=50, help=( "Run dreambooth validation every X epochs. Dreambooth validation consists of running the prompt" " `args.validation_prompt` multiple times: `args.num_validation_images`." ), ) parser.add_argument( "--with_prior_preservation", default=False, action="store_true", help="Flag to add prior preservation loss.", ) parser.add_argument("--prior_loss_weight", type=float, default=1.0, help="The weight of prior preservation loss.") parser.add_argument( "--num_class_images", type=int, default=100, help=( "Minimal class images for prior preservation loss. If there are not enough images already present in" " class_data_dir, additional images will be sampled with class_prompt." ), ) parser.add_argument( "--output_dir", type=str, default="lora-dreambooth-model", help="The output directory where the model predictions and checkpoints will be written.", ) parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.") parser.add_argument( "--resolution", type=int, default=512, help=( "The resolution for input images, all the images in the train/validation dataset will be resized to this" " resolution" ), ) parser.add_argument( "--center_crop", default=False, action="store_true", help=( "Whether to center crop the input images to the resolution. If not set, the images will be randomly" " cropped. The images will be resized to the resolution first before cropping." ), ) parser.add_argument( "--train_text_encoder", action="store_true", help="Whether to train the text encoder. If set, the text encoder should be float32 precision.", ) parser.add_argument( "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader." ) parser.add_argument( "--sample_batch_size", type=int, default=4, help="Batch size (per device) for sampling images." ) parser.add_argument("--num_train_epochs", type=int, default=1) parser.add_argument( "--max_train_steps", type=int, default=None, help="Total number of training steps to perform. If provided, overrides num_train_epochs.", ) parser.add_argument( "--checkpointing_steps", type=int, default=500, help=( "Save a checkpoint of the training state every X updates. These checkpoints can be used both as final" " checkpoints in case they are better than the last checkpoint, and are also suitable for resuming" " training using `--resume_from_checkpoint`." ), ) parser.add_argument( "--checkpoints_total_limit", type=int, default=None, help=( "Max number of checkpoints to store. Passed as `total_limit` to the `Accelerator` `ProjectConfiguration`." " See Accelerator::save_state https://huggingface.co/docs/accelerate/package_reference/accelerator#accelerate.Accelerator.save_state" " for more docs" ), ) parser.add_argument( "--resume_from_checkpoint", type=str, default=None, help=( "Whether training should be resumed from a previous checkpoint. Use a path saved by" ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.' ), ) parser.add_argument( "--gradient_accumulation_steps", type=int, default=1, help="Number of updates steps to accumulate before performing a backward/update pass.", ) parser.add_argument( "--gradient_checkpointing", action="store_true", help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.", ) parser.add_argument( "--learning_rate", type=float, default=5e-4, help="Initial learning rate (after the potential warmup period) to use.", ) parser.add_argument( "--scale_lr", action="store_true", default=False, help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.", ) parser.add_argument( "--lr_scheduler", type=str, default="constant", help=( 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' ' "constant", "constant_with_warmup"]' ), ) parser.add_argument( "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler." ) parser.add_argument( "--lr_num_cycles", type=int, default=1, help="Number of hard resets of the lr in cosine_with_restarts scheduler.", ) parser.add_argument("--lr_power", type=float, default=1.0, help="Power factor of the polynomial scheduler.") parser.add_argument( "--dataloader_num_workers", type=int, default=0, help=( "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process." ), ) parser.add_argument( "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes." ) parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.") parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.") parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.") parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer") parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.") parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.") parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.") parser.add_argument( "--hub_model_id", type=str, default=None, help="The name of the repository to keep in sync with the local `output_dir`.", ) parser.add_argument( "--logging_dir", type=str, default="logs", help=( "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to" " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***." ), ) parser.add_argument( "--allow_tf32", action="store_true", help=( "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see" " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices" ), ) parser.add_argument( "--report_to", type=str, default="tensorboard", help=( 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`' ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.' ), ) parser.add_argument( "--mixed_precision", type=str, default=None, choices=["no", "fp16", "bf16"], help=( "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=" " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the" " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config." ), ) parser.add_argument( "--prior_generation_precision", type=str, default=None, choices=["no", "fp32", "fp16", "bf16"], help=( "Choose prior generation precision between fp32, fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=" " 1.10.and an Nvidia Ampere GPU. Default to fp16 if a GPU is available else fp32." ), ) parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank") parser.add_argument( "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers." ) parser.add_argument( "--pre_compute_text_embeddings", action="store_true", help="Whether or not to pre-compute text embeddings. If text embeddings are pre-computed, the text encoder will not be kept in memory during training and will leave more GPU memory available for training the rest of the model. This is not compatible with `--train_text_encoder`.", ) parser.add_argument( "--tokenizer_max_length", type=int, default=None, required=False, help="The maximum length of the tokenizer. If not set, will default to the tokenizer's max length.", ) parser.add_argument( "--text_encoder_use_attention_mask", action="store_true", required=False, help="Whether to use attention mask for the text encoder", ) parser.add_argument( "--validation_images", required=False, default=None, nargs="+", help="Optional set of images to use for validation. Used when the target pipeline takes an initial image as input such as when training image variation or superresolution.", ) parser.add_argument( "--class_labels_conditioning", required=False, default=None, help="The optional `class_label` conditioning to pass to the unet, available values are `timesteps`.", ) parser.add_argument( "--lora_rank", type=int, default=4, help="rank of lora." ) if input_args is not None: args = parser.parse_args(input_args) else: args = parser.parse_args() env_local_rank = int(os.environ.get("LOCAL_RANK", -1)) if env_local_rank != -1 and env_local_rank != args.local_rank: args.local_rank = env_local_rank if args.with_prior_preservation: if args.class_data_dir is None: raise ValueError("You must specify a data directory for class images.") if args.class_prompt is None: raise ValueError("You must specify prompt for class images.") else: # logger is not available yet if args.class_data_dir is not None: warnings.warn("You need not use --class_data_dir without --with_prior_preservation.") if args.class_prompt is not None: warnings.warn("You need not use --class_prompt without --with_prior_preservation.") if args.train_text_encoder and args.pre_compute_text_embeddings: raise ValueError("`--train_text_encoder` cannot be used with `--pre_compute_text_embeddings`") return args class DreamBoothDataset(Dataset): """ A dataset to prepare the instance and class images with the prompts for fine-tuning the model. It pre-processes the images and the tokenizes prompts. """ def __init__( self, instance_data_root, instance_prompt, tokenizer, class_data_root=None, class_prompt=None, class_num=None, size=512, center_crop=False, encoder_hidden_states=None, instance_prompt_encoder_hidden_states=None, tokenizer_max_length=None, ): self.size = size self.center_crop = center_crop self.tokenizer = tokenizer self.encoder_hidden_states = encoder_hidden_states self.instance_prompt_encoder_hidden_states = instance_prompt_encoder_hidden_states self.tokenizer_max_length = tokenizer_max_length self.instance_data_root = Path(instance_data_root) if not self.instance_data_root.exists(): raise ValueError("Instance images root doesn't exists.") self.instance_images_path = list(Path(instance_data_root).iterdir()) self.num_instance_images = len(self.instance_images_path) self.instance_prompt = instance_prompt self._length = self.num_instance_images if class_data_root is not None: self.class_data_root = Path(class_data_root) self.class_data_root.mkdir(parents=True, exist_ok=True) self.class_images_path = list(self.class_data_root.iterdir()) if class_num is not None: self.num_class_images = min(len(self.class_images_path), class_num) else: self.num_class_images = len(self.class_images_path) self._length = max(self.num_class_images, self.num_instance_images) self.class_prompt = class_prompt else: self.class_data_root = None self.image_transforms = transforms.Compose( [ transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR), transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size), transforms.ToTensor(), transforms.Normalize([0.5], [0.5]), ] ) def __len__(self): return self._length def __getitem__(self, index): example = {} instance_image = Image.open(self.instance_images_path[index % self.num_instance_images]) instance_image = exif_transpose(instance_image) if not instance_image.mode == "RGB": instance_image = instance_image.convert("RGB") example["instance_images"] = self.image_transforms(instance_image) if self.encoder_hidden_states is not None: example["instance_prompt_ids"] = self.encoder_hidden_states else: text_inputs = tokenize_prompt( self.tokenizer, self.instance_prompt, tokenizer_max_length=self.tokenizer_max_length ) example["instance_prompt_ids"] = text_inputs.input_ids example["instance_attention_mask"] = text_inputs.attention_mask if self.class_data_root: class_image = Image.open(self.class_images_path[index % self.num_class_images]) class_image = exif_transpose(class_image) if not class_image.mode == "RGB": class_image = class_image.convert("RGB") example["class_images"] = self.image_transforms(class_image) if self.instance_prompt_encoder_hidden_states is not None: example["class_prompt_ids"] = self.instance_prompt_encoder_hidden_states else: class_text_inputs = tokenize_prompt( self.tokenizer, self.class_prompt, tokenizer_max_length=self.tokenizer_max_length ) example["class_prompt_ids"] = class_text_inputs.input_ids example["class_attention_mask"] = class_text_inputs.attention_mask return example def collate_fn(examples, with_prior_preservation=False): has_attention_mask = "instance_attention_mask" in examples[0] input_ids = [example["instance_prompt_ids"] for example in examples] pixel_values = [example["instance_images"] for example in examples] if has_attention_mask: attention_mask = [example["instance_attention_mask"] for example in examples] # Concat class and instance examples for prior preservation. # We do this to avoid doing two forward passes. if with_prior_preservation: input_ids += [example["class_prompt_ids"] for example in examples] pixel_values += [example["class_images"] for example in examples] if has_attention_mask: attention_mask += [example["class_attention_mask"] for example in examples] pixel_values = torch.stack(pixel_values) pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float() input_ids = torch.cat(input_ids, dim=0) batch = { "input_ids": input_ids, "pixel_values": pixel_values, } if has_attention_mask: batch["attention_mask"] = attention_mask return batch class PromptDataset(Dataset): "A simple dataset to prepare the prompts to generate class images on multiple GPUs." def __init__(self, prompt, num_samples): self.prompt = prompt self.num_samples = num_samples def __len__(self): return self.num_samples def __getitem__(self, index): example = {} example["prompt"] = self.prompt example["index"] = index return example def tokenize_prompt(tokenizer, prompt, tokenizer_max_length=None): if tokenizer_max_length is not None: max_length = tokenizer_max_length else: max_length = tokenizer.model_max_length text_inputs = tokenizer( prompt, truncation=True, padding="max_length", max_length=max_length, return_tensors="pt", ) return text_inputs def encode_prompt(text_encoder, input_ids, attention_mask, text_encoder_use_attention_mask=None): text_input_ids = input_ids.to(text_encoder.device) if text_encoder_use_attention_mask: attention_mask = attention_mask.to(text_encoder.device) else: attention_mask = None prompt_embeds = text_encoder( text_input_ids, attention_mask=attention_mask, ) prompt_embeds = prompt_embeds[0] return prompt_embeds def main(args): logging_dir = Path(args.output_dir, args.logging_dir) accelerator_project_config = ProjectConfiguration(total_limit=args.checkpoints_total_limit) accelerator = Accelerator( gradient_accumulation_steps=args.gradient_accumulation_steps, mixed_precision=args.mixed_precision, log_with=args.report_to, logging_dir=logging_dir, project_config=accelerator_project_config, ) if args.report_to == "wandb": if not is_wandb_available(): raise ImportError("Make sure to install wandb if you want to use it for logging during training.") import wandb # Currently, it's not possible to do gradient accumulation when training two models with accelerate.accumulate # This will be enabled soon in accelerate. For now, we don't allow gradient accumulation when training two models. # TODO (sayakpaul): Remove this check when gradient accumulation with two models is enabled in accelerate. if args.train_text_encoder and args.gradient_accumulation_steps > 1 and accelerator.num_processes > 1: raise ValueError( "Gradient accumulation is not supported when training the text encoder in distributed training. " "Please set gradient_accumulation_steps to 1. This feature will be supported in the future." ) # Make one log on every process with the configuration for debugging. logging.basicConfig( format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", datefmt="%m/%d/%Y %H:%M:%S", level=logging.INFO, ) logger.info(accelerator.state, main_process_only=False) if accelerator.is_local_main_process: transformers.utils.logging.set_verbosity_warning() diffusers.utils.logging.set_verbosity_info() else: transformers.utils.logging.set_verbosity_error() diffusers.utils.logging.set_verbosity_error() # If passed along, set the training seed now. if args.seed is not None: set_seed(args.seed) # Generate class images if prior preservation is enabled. if args.with_prior_preservation: class_images_dir = Path(args.class_data_dir) if not class_images_dir.exists(): class_images_dir.mkdir(parents=True) cur_class_images = len(list(class_images_dir.iterdir())) if cur_class_images < args.num_class_images: torch_dtype = torch.float16 if accelerator.device.type == "cuda" else torch.float32 if args.prior_generation_precision == "fp32": torch_dtype = torch.float32 elif args.prior_generation_precision == "fp16": torch_dtype = torch.float16 elif args.prior_generation_precision == "bf16": torch_dtype = torch.bfloat16 pipeline = DiffusionPipeline.from_pretrained( args.pretrained_model_name_or_path, torch_dtype=torch_dtype, safety_checker=None, revision=args.revision, ) pipeline.set_progress_bar_config(disable=True) num_new_images = args.num_class_images - cur_class_images logger.info(f"Number of class images to sample: {num_new_images}.") sample_dataset = PromptDataset(args.class_prompt, num_new_images) sample_dataloader = torch.utils.data.DataLoader(sample_dataset, batch_size=args.sample_batch_size) sample_dataloader = accelerator.prepare(sample_dataloader) pipeline.to(accelerator.device) for example in tqdm( sample_dataloader, desc="Generating class images", disable=not accelerator.is_local_main_process ): images = pipeline(example["prompt"]).images for i, image in enumerate(images): hash_image = hashlib.sha1(image.tobytes()).hexdigest() image_filename = class_images_dir / f"{example['index'][i] + cur_class_images}-{hash_image}.jpg" image.save(image_filename) del pipeline if torch.cuda.is_available(): torch.cuda.empty_cache() # Handle the repository creation if accelerator.is_main_process: if args.output_dir is not None: os.makedirs(args.output_dir, exist_ok=True) if args.push_to_hub: repo_id = create_repo( repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token ).repo_id # Load the tokenizer if args.tokenizer_name: tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name, revision=args.revision, use_fast=False) elif args.pretrained_model_name_or_path: tokenizer = AutoTokenizer.from_pretrained( args.pretrained_model_name_or_path, subfolder="tokenizer", revision=args.revision, use_fast=False, ) # import correct text encoder class text_encoder_cls = import_model_class_from_model_name_or_path(args.pretrained_model_name_or_path, args.revision) # Load scheduler and models noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler") text_encoder = text_encoder_cls.from_pretrained( args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision ) try: vae = AutoencoderKL.from_pretrained( args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision ) except OSError: # IF does not have a VAE so let's just set it to None # We don't have to error out here vae = None unet = UNet2DConditionModel.from_pretrained( args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision ) # We only train the additional adapter LoRA layers if vae is not None: vae.requires_grad_(False) text_encoder.requires_grad_(False) unet.requires_grad_(False) # For mixed precision training we cast the text_encoder and vae weights to half-precision # as these models are only used for inference, keeping weights in full precision is not required. weight_dtype = torch.float32 if accelerator.mixed_precision == "fp16": weight_dtype = torch.float16 elif accelerator.mixed_precision == "bf16": weight_dtype = torch.bfloat16 # Move unet, vae and text_encoder to device and cast to weight_dtype unet.to(accelerator.device, dtype=weight_dtype) if vae is not None: vae.to(accelerator.device, dtype=weight_dtype) text_encoder.to(accelerator.device, dtype=weight_dtype) if args.enable_xformers_memory_efficient_attention: if is_xformers_available(): import xformers xformers_version = version.parse(xformers.__version__) if xformers_version == version.parse("0.0.16"): logger.warn( "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details." ) unet.enable_xformers_memory_efficient_attention() else: raise ValueError("xformers is not available. Make sure it is installed correctly") # now we will add new LoRA weights to the attention layers # It's important to realize here how many attention weights will be added and of which sizes # The sizes of the attention layers consist only of two different variables: # 1) - the "hidden_size", which is increased according to `unet.config.block_out_channels`. # 2) - the "cross attention size", which is set to `unet.config.cross_attention_dim`. # Let's first see how many attention processors we will have to set. # For Stable Diffusion, it should be equal to: # - down blocks (2x attention layers) * (2x transformer layers) * (3x down blocks) = 12 # - mid blocks (2x attention layers) * (1x transformer layers) * (1x mid blocks) = 2 # - up blocks (2x attention layers) * (3x transformer layers) * (3x down blocks) = 18 # => 32 layers # Set correct lora layers unet_lora_attn_procs = {} for name, attn_processor in unet.attn_processors.items(): cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim if name.startswith("mid_block"): hidden_size = unet.config.block_out_channels[-1] elif name.startswith("up_blocks"): block_id = int(name[len("up_blocks.")]) hidden_size = list(reversed(unet.config.block_out_channels))[block_id] elif name.startswith("down_blocks"): block_id = int(name[len("down_blocks.")]) hidden_size = unet.config.block_out_channels[block_id] if isinstance(attn_processor, (AttnAddedKVProcessor, SlicedAttnAddedKVProcessor, AttnAddedKVProcessor2_0)): lora_attn_processor_class = LoRAAttnAddedKVProcessor else: lora_attn_processor_class = ( LoRAAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else LoRAAttnProcessor ) unet_lora_attn_procs[name] = lora_attn_processor_class( hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, rank=args.lora_rank ) unet.set_attn_processor(unet_lora_attn_procs) unet_lora_layers = AttnProcsLayers(unet.attn_processors) # The text encoder comes from 🤗 transformers, so we cannot directly modify it. # So, instead, we monkey-patch the forward calls of its attention-blocks. For this, # we first load a dummy pipeline with the text encoder and then do the monkey-patching. text_encoder_lora_layers = None if args.train_text_encoder: text_lora_attn_procs = {} for name, module in text_encoder.named_modules(): if name.endswith(TEXT_ENCODER_ATTN_MODULE): text_lora_attn_procs[name] = LoRAAttnProcessor( hidden_size=module.out_proj.out_features, cross_attention_dim=None ) text_encoder_lora_layers = AttnProcsLayers(text_lora_attn_procs) temp_pipeline = DiffusionPipeline.from_pretrained( args.pretrained_model_name_or_path, text_encoder=text_encoder ) temp_pipeline._modify_text_encoder(text_lora_attn_procs) text_encoder = temp_pipeline.text_encoder del temp_pipeline # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format def save_model_hook(models, weights, output_dir): # there are only two options here. Either are just the unet attn processor layers # or there are the unet and text encoder atten layers unet_lora_layers_to_save = None text_encoder_lora_layers_to_save = None if args.train_text_encoder: text_encoder_keys = accelerator.unwrap_model(text_encoder_lora_layers).state_dict().keys() unet_keys = accelerator.unwrap_model(unet_lora_layers).state_dict().keys() for model in models: state_dict = model.state_dict() if ( text_encoder_lora_layers is not None and text_encoder_keys is not None and state_dict.keys() == text_encoder_keys ): # text encoder text_encoder_lora_layers_to_save = state_dict elif state_dict.keys() == unet_keys: # unet unet_lora_layers_to_save = state_dict # make sure to pop weight so that corresponding model is not saved again weights.pop() LoraLoaderMixin.save_lora_weights( output_dir, unet_lora_layers=unet_lora_layers_to_save, text_encoder_lora_layers=text_encoder_lora_layers_to_save, ) def load_model_hook(models, input_dir): # Note we DON'T pass the unet and text encoder here an purpose # so that the we don't accidentally override the LoRA layers of # unet_lora_layers and text_encoder_lora_layers which are stored in `models` # with new torch.nn.Modules / weights. We simply use the pipeline class as # an easy way to load the lora checkpoints temp_pipeline = DiffusionPipeline.from_pretrained( args.pretrained_model_name_or_path, revision=args.revision, torch_dtype=weight_dtype, ) temp_pipeline.load_lora_weights(input_dir) # load lora weights into models models[0].load_state_dict(AttnProcsLayers(temp_pipeline.unet.attn_processors).state_dict()) if len(models) > 1: models[1].load_state_dict(AttnProcsLayers(temp_pipeline.text_encoder_lora_attn_procs).state_dict()) # delete temporary pipeline and pop models del temp_pipeline for _ in range(len(models)): models.pop() accelerator.register_save_state_pre_hook(save_model_hook) accelerator.register_load_state_pre_hook(load_model_hook) # Enable TF32 for faster training on Ampere GPUs, # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices if args.allow_tf32: torch.backends.cuda.matmul.allow_tf32 = True if args.scale_lr: args.learning_rate = ( args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes ) # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs if args.use_8bit_adam: try: import bitsandbytes as bnb except ImportError: raise ImportError( "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`." ) optimizer_class = bnb.optim.AdamW8bit else: optimizer_class = torch.optim.AdamW # Optimizer creation params_to_optimize = ( itertools.chain(unet_lora_layers.parameters(), text_encoder_lora_layers.parameters()) if args.train_text_encoder else unet_lora_layers.parameters() ) optimizer = optimizer_class( params_to_optimize, lr=args.learning_rate, betas=(args.adam_beta1, args.adam_beta2), weight_decay=args.adam_weight_decay, eps=args.adam_epsilon, ) if args.pre_compute_text_embeddings: def compute_text_embeddings(prompt): with torch.no_grad(): text_inputs = tokenize_prompt(tokenizer, prompt, tokenizer_max_length=args.tokenizer_max_length) prompt_embeds = encode_prompt( text_encoder, text_inputs.input_ids, text_inputs.attention_mask, text_encoder_use_attention_mask=args.text_encoder_use_attention_mask, ) return prompt_embeds pre_computed_encoder_hidden_states = compute_text_embeddings(args.instance_prompt) validation_prompt_negative_prompt_embeds = compute_text_embeddings("") if args.validation_prompt is not None: validation_prompt_encoder_hidden_states = compute_text_embeddings(args.validation_prompt) else: validation_prompt_encoder_hidden_states = None if args.instance_prompt is not None: pre_computed_instance_prompt_encoder_hidden_states = compute_text_embeddings(args.instance_prompt) else: pre_computed_instance_prompt_encoder_hidden_states = None text_encoder = None tokenizer = None gc.collect() torch.cuda.empty_cache() else: pre_computed_encoder_hidden_states = None validation_prompt_encoder_hidden_states = None validation_prompt_negative_prompt_embeds = None pre_computed_instance_prompt_encoder_hidden_states = None # Dataset and DataLoaders creation: train_dataset = DreamBoothDataset( instance_data_root=args.instance_data_dir, instance_prompt=args.instance_prompt, class_data_root=args.class_data_dir if args.with_prior_preservation else None, class_prompt=args.class_prompt, class_num=args.num_class_images, tokenizer=tokenizer, size=args.resolution, center_crop=args.center_crop, encoder_hidden_states=pre_computed_encoder_hidden_states, instance_prompt_encoder_hidden_states=pre_computed_instance_prompt_encoder_hidden_states, tokenizer_max_length=args.tokenizer_max_length, ) train_dataloader = torch.utils.data.DataLoader( train_dataset, batch_size=args.train_batch_size, shuffle=True, collate_fn=lambda examples: collate_fn(examples, args.with_prior_preservation), num_workers=args.dataloader_num_workers, ) # Scheduler and math around the number of training steps. overrode_max_train_steps = False num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) if args.max_train_steps is None: args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch overrode_max_train_steps = True lr_scheduler = get_scheduler( args.lr_scheduler, optimizer=optimizer, num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps, num_training_steps=args.max_train_steps * args.gradient_accumulation_steps, num_cycles=args.lr_num_cycles, power=args.lr_power, ) # Prepare everything with our `accelerator`. if args.train_text_encoder: unet_lora_layers, text_encoder_lora_layers, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( unet_lora_layers, text_encoder_lora_layers, optimizer, train_dataloader, lr_scheduler ) else: unet_lora_layers, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( unet_lora_layers, optimizer, train_dataloader, lr_scheduler ) # We need to recalculate our total training steps as the size of the training dataloader may have changed. num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) if overrode_max_train_steps: args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch # Afterwards we recalculate our number of training epochs args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) # We need to initialize the trackers we use, and also store our configuration. # The trackers initializes automatically on the main process. if accelerator.is_main_process: accelerator.init_trackers("dreambooth-lora", config=vars(args)) # Train! total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps logger.info("***** Running training *****") logger.info(f" Num examples = {len(train_dataset)}") logger.info(f" Num batches each epoch = {len(train_dataloader)}") logger.info(f" Num Epochs = {args.num_train_epochs}") logger.info(f" Instantaneous batch size per device = {args.train_batch_size}") logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") logger.info(f" Total optimization steps = {args.max_train_steps}") global_step = 0 first_epoch = 0 # Potentially load in the weights and states from a previous save if args.resume_from_checkpoint: if args.resume_from_checkpoint != "latest": path = os.path.basename(args.resume_from_checkpoint) else: # Get the mos recent checkpoint dirs = os.listdir(args.output_dir) dirs = [d for d in dirs if d.startswith("checkpoint")] dirs = sorted(dirs, key=lambda x: int(x.split("-")[1])) path = dirs[-1] if len(dirs) > 0 else None if path is None: accelerator.print( f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run." ) args.resume_from_checkpoint = None else: accelerator.print(f"Resuming from checkpoint {path}") accelerator.load_state(os.path.join(args.output_dir, path)) global_step = int(path.split("-")[1]) resume_global_step = global_step * args.gradient_accumulation_steps first_epoch = global_step // num_update_steps_per_epoch resume_step = resume_global_step % (num_update_steps_per_epoch * args.gradient_accumulation_steps) # Only show the progress bar once on each machine. progress_bar = tqdm(range(global_step, args.max_train_steps), disable=not accelerator.is_local_main_process) progress_bar.set_description("Steps") for epoch in range(first_epoch, args.num_train_epochs): unet.train() if args.train_text_encoder: text_encoder.train() for step, batch in enumerate(train_dataloader): # Skip steps until we reach the resumed step if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step: if step % args.gradient_accumulation_steps == 0: progress_bar.update(1) continue with accelerator.accumulate(unet): pixel_values = batch["pixel_values"].to(dtype=weight_dtype) if vae is not None: # Convert images to latent space model_input = vae.encode(pixel_values).latent_dist model_input = model_input.sample() * vae.config.scaling_factor else: model_input = pixel_values # Sample noise that we'll add to the latents noise = torch.randn_like(model_input) bsz, channels, height, width = model_input.shape # Sample a random timestep for each image timesteps = torch.randint( 0, noise_scheduler.config.num_train_timesteps, (bsz,), device=model_input.device ) timesteps = timesteps.long() # Add noise to the model input according to the noise magnitude at each timestep # (this is the forward diffusion process) noisy_model_input = noise_scheduler.add_noise(model_input, noise, timesteps) # Get the text embedding for conditioning if args.pre_compute_text_embeddings: encoder_hidden_states = batch["input_ids"] else: encoder_hidden_states = encode_prompt( text_encoder, batch["input_ids"], batch["attention_mask"], text_encoder_use_attention_mask=args.text_encoder_use_attention_mask, ) if accelerator.unwrap_model(unet).config.in_channels == channels * 2: noisy_model_input = torch.cat([noisy_model_input, noisy_model_input], dim=1) if args.class_labels_conditioning == "timesteps": class_labels = timesteps else: class_labels = None # Predict the noise residual model_pred = unet( noisy_model_input, timesteps, encoder_hidden_states, class_labels=class_labels ).sample # if model predicts variance, throw away the prediction. we will only train on the # simplified training objective. This means that all schedulers using the fine tuned # model must be configured to use one of the fixed variance variance types. if model_pred.shape[1] == 6: model_pred, _ = torch.chunk(model_pred, 2, dim=1) # Get the target for loss depending on the prediction type if noise_scheduler.config.prediction_type == "epsilon": target = noise elif noise_scheduler.config.prediction_type == "v_prediction": target = noise_scheduler.get_velocity(model_input, noise, timesteps) else: raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") if args.with_prior_preservation: # Chunk the noise and model_pred into two parts and compute the loss on each part separately. model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0) target, target_prior = torch.chunk(target, 2, dim=0) # Compute instance loss loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") # Compute prior loss prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean") # Add the prior loss to the instance loss. loss = loss + args.prior_loss_weight * prior_loss else: loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") accelerator.backward(loss) if accelerator.sync_gradients: params_to_clip = ( itertools.chain(unet_lora_layers.parameters(), text_encoder_lora_layers.parameters()) if args.train_text_encoder else unet_lora_layers.parameters() ) accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) optimizer.step() lr_scheduler.step() optimizer.zero_grad() # Checks if the accelerator has performed an optimization step behind the scenes if accelerator.sync_gradients: progress_bar.update(1) global_step += 1 if accelerator.is_main_process: if global_step % args.checkpointing_steps == 0: save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}") accelerator.save_state(save_path) logger.info(f"Saved state to {save_path}") logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]} progress_bar.set_postfix(**logs) accelerator.log(logs, step=global_step) if global_step >= args.max_train_steps: break if accelerator.is_main_process: if args.validation_prompt is not None and epoch % args.validation_epochs == 0: logger.info( f"Running validation... \n Generating {args.num_validation_images} images with prompt:" f" {args.validation_prompt}." ) # create pipeline pipeline = DiffusionPipeline.from_pretrained( args.pretrained_model_name_or_path, unet=accelerator.unwrap_model(unet), text_encoder=None if args.pre_compute_text_embeddings else accelerator.unwrap_model(text_encoder), revision=args.revision, torch_dtype=weight_dtype, ) # We train on the simplified learning objective. If we were previously predicting a variance, we need the scheduler to ignore it scheduler_args = {} if "variance_type" in pipeline.scheduler.config: variance_type = pipeline.scheduler.config.variance_type if variance_type in ["learned", "learned_range"]: variance_type = "fixed_small" scheduler_args["variance_type"] = variance_type pipeline.scheduler = DPMSolverMultistepScheduler.from_config( pipeline.scheduler.config, **scheduler_args ) pipeline = pipeline.to(accelerator.device) pipeline.set_progress_bar_config(disable=True) # run inference generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None if args.pre_compute_text_embeddings: pipeline_args = { "prompt_embeds": validation_prompt_encoder_hidden_states, "negative_prompt_embeds": validation_prompt_negative_prompt_embeds, } else: pipeline_args = {"prompt": args.validation_prompt} if args.validation_images is None: images = [ pipeline(**pipeline_args, generator=generator).images[0] for _ in range(args.num_validation_images) ] else: images = [] for image in args.validation_images: image = Image.open(image) image = pipeline(**pipeline_args, image=image, generator=generator).images[0] images.append(image) for tracker in accelerator.trackers: if tracker.name == "tensorboard": np_images = np.stack([np.asarray(img) for img in images]) tracker.writer.add_images("validation", np_images, epoch, dataformats="NHWC") if tracker.name == "wandb": tracker.log( { "validation": [ wandb.Image(image, caption=f"{i}: {args.validation_prompt}") for i, image in enumerate(images) ] } ) del pipeline torch.cuda.empty_cache() # Save the lora layers accelerator.wait_for_everyone() if accelerator.is_main_process: unet = unet.to(torch.float32) unet_lora_layers = accelerator.unwrap_model(unet_lora_layers) if text_encoder is not None: text_encoder = text_encoder.to(torch.float32) text_encoder_lora_layers = accelerator.unwrap_model(text_encoder_lora_layers) LoraLoaderMixin.save_lora_weights( save_directory=args.output_dir, unet_lora_layers=unet_lora_layers, text_encoder_lora_layers=text_encoder_lora_layers, ) accelerator.end_training() if __name__ == "__main__": args = parse_args() main(args)