import os from config import Config import shutil import random import math import numpy as np import torch import torch.nn.functional as F import transformers from accelerate import Accelerator from accelerate.logging import get_logger from accelerate.utils import ( DistributedDataParallelKwargs, ProjectConfiguration, set_seed, ) from datasets import load_dataset from huggingface_hub import create_repo, upload_folder from packaging import version from peft import LoraConfig, set_peft_model_state_dict from peft.utils import get_peft_model_state_dict from torchvision import transforms from torchvision.transforms.functional import crop from tqdm.auto import tqdm from transformers import AutoTokenizer, PretrainedConfig import diffusers logger = get_logger(__name__) from diffusers import ( AutoencoderKL, DDPMScheduler, StableDiffusionXLPipeline, UNet2DConditionModel, ) from diffusers.loaders import LoraLoaderMixin from diffusers.optimization import get_scheduler from diffusers.training_utils import ( _set_state_dict_into_text_encoder, cast_training_params, compute_snr, ) from diffusers.utils import ( convert_state_dict_to_diffusers, convert_unet_state_dict_to_peft, is_wandb_available, is_xformers_available, ) from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card from diffusers.utils.torch_utils import is_compiled_module logger = get_logger(__name__) def save_model_card( repo_id: str, images: list = None, base_model: str = None, dataset_name: str = None, train_text_encoder: bool = False, repo_folder: str = None, vae_path: str = None, ): img_str = "" if images is not None: for i, image in enumerate(images): image.save(os.path.join(repo_folder, f"image_{i}.png")) img_str += f"![img_{i}](./image_{i}.png)\n" img_str = "" # Declare the img_str variable model_description = "SDXL Product Images" model_card = load_or_create_model_card( repo_id_or_path=repo_id, from_training=True, license="creativeml-openrail-m", base_model=base_model, model_description=model_description, inference=True, ) tags = [ "stable-diffusion-xl", "stable-diffusion-xl-diffusers", "text-to-image", "diffusers", "diffusers-training", ] model_card = populate_model_card(model_card, tags=tags) model_card.save(os.path.join(repo_folder, "README.md")) def import_model_class_from_model_name_or_path( pretrained_model_name_or_path: str, revision: str, subfolder: str = "text_encoder" ): text_encoder_config = PretrainedConfig.from_pretrained( pretrained_model_name_or_path, subfolder=subfolder, revision=revision ) model_class = text_encoder_config.architectures[0] if model_class == "CLIPTextModel": from transformers import CLIPTextModel return CLIPTextModel elif model_class == "CLIPTextModelWithProjection": from transformers import CLIPTextModelWithProjection return CLIPTextModelWithProjection else: raise ValueError(f"{model_class} is not supported.") def tokenize_prompt(tokenizer, prompt): text_inputs = tokenizer( prompt, padding="max_length", max_length=tokenizer.model_max_length, truncation=True, return_tensors="pt", ) text_input_ids = text_inputs.input_ids return text_input_ids def encode_prompt(text_encoders, tokenizers, prompt, text_input_ids_list=None): prompt_embeds_list = [] for i, text_encoder in enumerate(text_encoders): if tokenizers is not None: tokenizer = tokenizers[i] text_input_ids = tokenize_prompt(tokenizer, prompt) else: assert text_input_ids_list is not None text_input_ids = text_input_ids_list[i] prompt_embeds = text_encoder( text_input_ids.to(text_encoder.device), output_hidden_states=True, return_dict=False, ) # We are only ALWAYS interested in the pooled output of the final text encoder pooled_prompt_embeds = prompt_embeds[0] prompt_embeds = prompt_embeds[-1][-2] bs_embed, seq_len, _ = prompt_embeds.shape prompt_embeds = prompt_embeds.view(bs_embed, seq_len, -1) prompt_embeds_list.append(prompt_embeds) prompt_embeds = torch.concat(prompt_embeds_list, dim=-1) pooled_prompt_embeds = pooled_prompt_embeds.view(bs_embed, -1) return prompt_embeds, pooled_prompt_embeds def main(): config = Config() from pathlib import Path from contextlib import nullcontext if config.report_to == "wandb" and config.hub_token is not None: raise ValueError( "You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token." " Please use `huggingface-cli login` to authenticate with the Hub." ) logging_dir = Path(config.output_dir, config.logging_dir) if torch.backends.mps.is_available() and config.mixed_precision == "bf16": # due to pytorch#99272, MPS does not yet support bfloat16. raise ValueError( "Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead." ) accelerator_project_config = ProjectConfiguration( project_dir=config.output_dir, logging_dir=logging_dir ) kwargs = DistributedDataParallelKwargs(find_unused_parameters=True) accelerator = Accelerator( gradient_accumulation_steps=config.gradient_accumulation_steps, mixed_precision=config.mixed_precision, log_with=config.report_to, project_config=accelerator_project_config, kwargs_handlers=[kwargs], ) import logging if config.report_to == "wandb": if not is_wandb_available(): raise ImportError( "Make sure to install wandb if you want to use it for logging during training." ) import wandb # Make one log on every process with the configuration for debugging. logging.basicConfig( format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", datefmt="%m/%d/%Y %H:%M:%S", level=logging.INFO, ) from datasets import utils as datasets_utils logger.info(accelerator.state, main_process_only=False) if accelerator.is_local_main_process: datasets_utils.logging.set_verbosity_warning() transformers.utils.logging.set_verbosity_warning() diffusers.utils.logging.set_verbosity_info() else: datasets_utils.logging.set_verbosity_error() transformers.utils.logging.set_verbosity_error() diffusers.utils.logging.set_verbosity_error() # If passed along, set the training seed now. if config.seed is not None: set_seed(config.seed) # Handle the repository creation if accelerator.is_main_process: if config.output_dir is not None: os.makedirs(config.output_dir, exist_ok=True) if config.push_to_hub: repo_id = create_repo( repo_id=config.hub_model_id or Path(config.output_dir).name, exist_ok=True, token=config.hub_token, ).repo_id # Load the tokenizers tokenizer_one = AutoTokenizer.from_pretrained( config.pretrained_model_name_or_path, subfolder="tokenizer", revision=config.revision, use_fast=False, ) tokenizer_two = AutoTokenizer.from_pretrained( config.pretrained_model_name_or_path, subfolder="tokenizer_2", revision=config.revision, use_fast=False, ) # import correct text encoder classes text_encoder_cls_one = import_model_class_from_model_name_or_path( config.pretrained_model_name_or_path, config.revision ) text_encoder_cls_two = import_model_class_from_model_name_or_path( config.pretrained_model_name_or_path, config.revision, subfolder="text_encoder_2", ) # Load scheduler and models noise_scheduler = DDPMScheduler.from_pretrained( config.pretrained_model_name_or_path, subfolder="scheduler" ) text_encoder_one = text_encoder_cls_one.from_pretrained( config.pretrained_model_name_or_path, subfolder="text_encoder", revision=config.revision, variant=config.variant, ) text_encoder_two = text_encoder_cls_two.from_pretrained( config.pretrained_model_name_or_path, subfolder="text_encoder_2", revision=config.revision, variant=config.variant, ) vae_path = ( config.pretrained_model_name_or_path if config.pretrained_vae_model_name_or_path is None else config.pretrained_vae_model_name_or_path ) vae = AutoencoderKL.from_pretrained( vae_path, subfolder="vae" if config.pretrained_vae_model_name_or_path is None else None, revision=config.revision, variant=config.variant, ) unet = UNet2DConditionModel.from_pretrained( config.pretrained_model_name_or_path, subfolder="unet", revision=config.revision, variant=config.variant, ) # We only train the additional adapter LoRA layers vae.requires_grad_(False) text_encoder_one.requires_grad_(False) text_encoder_two.requires_grad_(False) unet.requires_grad_(False) # For mixed precision training we cast all non-trainable weights (vae, non-lora text_encoder and non-lora unet) to half-precision # as these weights are only used for inference, keeping weights in full precision is not required. weight_dtype = torch.float32 if accelerator.mixed_precision == "fp16": weight_dtype = torch.float16 elif accelerator.mixed_precision == "bf16": weight_dtype = torch.bfloat16 # Move unet, vae and text_encoder to device and cast to weight_dtype # The VAE is in float32 to avoid NaN losses. unet.to(accelerator.device, dtype=weight_dtype) if config.pretrained_vae_model_name_or_path is None: vae.to(accelerator.device, dtype=torch.float32) else: vae.to(accelerator.device, dtype=weight_dtype) text_encoder_one.to(accelerator.device, dtype=weight_dtype) text_encoder_two.to(accelerator.device, dtype=weight_dtype) if config.enable_xformers_memory_efficient_attention: if is_xformers_available(): import xformers xformers_version = version.parse(xformers.__version__) if xformers_version == version.parse("0.0.16"): logger.warning( "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details." ) unet.enable_xformers_memory_efficient_attention() else: raise ValueError( "xformers is not available. Make sure it is installed correctly" ) # now we will add new LoRA weights to the attention layers # Set correct lora layers unet_lora_config = LoraConfig( r=config.rank, lora_alpha=config.rank, init_lora_weights="gaussian", target_modules=["to_k", "to_q", "to_v", "to_out.0"], ) unet.add_adapter(unet_lora_config) # The text encoder comes from 🤗 transformers, we will also attach adapters to it. if config.train_text_encoder: # ensure that dtype is float32, even if rest of the model that isn't trained is loaded in fp16 text_lora_config = LoraConfig( r=config.rank, lora_alpha=config.rank, init_lora_weights="gaussian", target_modules=["q_proj", "k_proj", "v_proj", "out_proj"], ) text_encoder_one.add_adapter(text_lora_config) text_encoder_two.add_adapter(text_lora_config) def unwrap_model(model): model = accelerator.unwrap_model(model) model = model._orig_mod if is_compiled_module(model) else model return model # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format def save_model_hook(models, weights, output_dir): if accelerator.is_main_process: # there are only two options here. Either are just the unet attn processor layers # or there are the unet and text encoder attn layers unet_lora_layers_to_save = None text_encoder_one_lora_layers_to_save = None text_encoder_two_lora_layers_to_save = None for model in models: if isinstance(unwrap_model(model), type(unwrap_model(unet))): unet_lora_layers_to_save = convert_state_dict_to_diffusers( get_peft_model_state_dict(model) ) elif isinstance( unwrap_model(model), type(unwrap_model(text_encoder_one)) ): text_encoder_one_lora_layers_to_save = ( convert_state_dict_to_diffusers( get_peft_model_state_dict(model) ) ) elif isinstance( unwrap_model(model), type(unwrap_model(text_encoder_two)) ): text_encoder_two_lora_layers_to_save = ( convert_state_dict_to_diffusers( get_peft_model_state_dict(model) ) ) else: raise ValueError(f"unexpected save model: {model.__class__}") # make sure to pop weight so that corresponding model is not saved again if weights: weights.pop() StableDiffusionXLPipeline.save_lora_weights( output_dir, unet_lora_layers=unet_lora_layers_to_save, text_encoder_lora_layers=text_encoder_one_lora_layers_to_save, text_encoder_2_lora_layers=text_encoder_two_lora_layers_to_save, ) def load_model_hook(models, input_dir): unet_ = None text_encoder_one_ = None text_encoder_two_ = None while len(models) > 0: model = models.pop() if isinstance(model, type(unwrap_model(unet))): unet_ = model elif isinstance(model, type(unwrap_model(text_encoder_one))): text_encoder_one_ = model elif isinstance(model, type(unwrap_model(text_encoder_two))): text_encoder_two_ = model else: raise ValueError(f"unexpected save model: {model.__class__}") lora_state_dict, _ = LoraLoaderMixin.lora_state_dict(input_dir) unet_state_dict = { f'{k.replace("unet.", "")}': v for k, v in lora_state_dict.items() if k.startswith("unet.") } unet_state_dict = convert_unet_state_dict_to_peft(unet_state_dict) incompatible_keys = set_peft_model_state_dict( unet_, unet_state_dict, adapter_name="default" ) if incompatible_keys is not None: # check only for unexpected keys unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None) if unexpected_keys: logger.warning( f"Loading adapter weights from state_dict led to unexpected keys not found in the model: " f" {unexpected_keys}. " ) if config.train_text_encoder: _set_state_dict_into_text_encoder( lora_state_dict, prefix="text_encoder.", text_encoder=text_encoder_one_ ) _set_state_dict_into_text_encoder( lora_state_dict, prefix="text_encoder_2.", text_encoder=text_encoder_two_, ) # Make sure the trainable params are in float32. This is again needed since the base models # are in `weight_dtype`. More details: # https://github.com/huggingface/diffusers/pull/6514#discussion_r1449796804 if config.mixed_precision == "fp16": models = [unet_] if config.train_text_encoder: models.extend([text_encoder_one_, text_encoder_two_]) cast_training_params(models, dtype=torch.float32) accelerator.register_save_state_pre_hook(save_model_hook) accelerator.register_load_state_pre_hook(load_model_hook) if config.gradient_checkpointing: unet.enable_gradient_checkpointing() if config.train_text_encoder: text_encoder_one.gradient_checkpointing_enable() text_encoder_two.gradient_checkpointing_enable() # Enable TF32 for faster training on Ampere GPUs, # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices if config.allow_tf32: torch.backends.cuda.matmul.allow_tf32 = True if config.scale_lr: config.learning_rate = ( config.learning_rate * config.gradient_accumulation_steps * config.train_batch_size * accelerator.num_processes ) # Make sure the trainable params are in float32. if config.mixed_precision == "fp16": models = [unet] if config.train_text_encoder: models.extend([text_encoder_one, text_encoder_two]) cast_training_params(models, dtype=torch.float32) # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs if config.use_8bit_adam: try: import bitsandbytes as bnb except ImportError: raise ImportError( "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`." ) optimizer_class = bnb.optim.AdamW8bit else: optimizer_class = torch.optim.AdamW # Optimizer creation params_to_optimize = list(filter(lambda p: p.requires_grad, unet.parameters())) if config.train_text_encoder: params_to_optimize = ( params_to_optimize + list(filter(lambda p: p.requires_grad, text_encoder_one.parameters())) + list(filter(lambda p: p.requires_grad, text_encoder_two.parameters())) ) optimizer = optimizer_class( params_to_optimize, lr=config.learning_rate, betas=(config.adam_beta1, config.adam_beta2), weight_decay=config.adam_weight_decay, eps=config.adam_epsilon, ) # Get the datasets: you can either provide your own training and evaluation files (see below) # or specify a Dataset from the hub (the dataset will be downloaded automatically from the datasets Hub). # In distributed training, the load_dataset function guarantees that only one local process can concurrently # download the dataset. if config.dataset_name is not None: # Downloading and loading a dataset from the hub. dataset = load_dataset( config.dataset_name, config.dataset_config_name, cache_dir=config.cache_dir, data_dir=config.train_data_dir, ) else: data_files = {} if config.train_data_dir is not None: data_files["test"] = os.path.join(config.train_data_dir, "**") dataset = load_dataset( "imagefolder", data_files=data_files, cache_dir=config.cache_dir, ) # See more about loading custom images at # https://huggingface.co/docs/datasets/v2.4.0/en/image_load#imagefolder # Preprocessing the datasets. # We need to tokenize inputs and targets. column_names = dataset["test"].column_names # 6. Get the column names for input/target. DATASET_NAME_MAPPING = { "lambdalabs/pokemon-blip-captions": ("image", "text"), } dataset_columns = DATASET_NAME_MAPPING.get(config.dataset_name, None) if config.image_column is None: image_column = ( dataset_columns[0] if dataset_columns is not None else column_names[0] ) else: image_column = config.image_column if image_column not in column_names: raise ValueError( f"--image_column' value '{config.image_column}' needs to be one of: {', '.join(column_names)}" ) if config.caption_column is None: caption_column = ( dataset_columns[1] if dataset_columns is not None else column_names[1] ) else: caption_column = config.caption_column if caption_column not in column_names: raise ValueError( f"--caption_column' value '{config.caption_column}' needs to be one of: {', '.join(column_names)}" ) # Preprocessing the datasets. # We need to tokenize input captions and transform the images. def tokenize_captions(examples, is_train=True): captions = [] for caption in examples[caption_column]: if isinstance(caption, str): captions.append(caption) elif isinstance(caption, (list, np.ndarray)): # take a random caption if there are multiple captions.append(random.choice(caption) if is_train else caption[0]) else: raise ValueError( f"Caption column `{caption_column}` should contain either strings or lists of strings." ) tokens_one = tokenize_prompt(tokenizer_one, captions) tokens_two = tokenize_prompt(tokenizer_two, captions) return tokens_one, tokens_two # Preprocessing the datasets. train_resize = transforms.Resize( config.resolution, interpolation=transforms.InterpolationMode.BILINEAR ) train_crop = ( transforms.CenterCrop(config.resolution) if config.center_crop else transforms.RandomCrop(config.resolution) ) train_flip = transforms.RandomHorizontalFlip(p=1.0) train_transforms = transforms.Compose( [ transforms.ToTensor(), transforms.Normalize([0.5], [0.5]), ] ) def preprocess_train(examples): images = [image.convert("RGB") for image in examples[image_column]] # image aug original_sizes = [] all_images = [] crop_top_lefts = [] for image in images: original_sizes.append((image.height, image.width)) image = train_resize(image) if config.random_flip and random.random() < 0.5: # flip image = train_flip(image) if config.center_crop: y1 = max(0, int(round((image.height - config.resolution) / 2.0))) x1 = max(0, int(round((image.width - config.resolution) / 2.0))) image = train_crop(image) else: y1, x1, h, w = train_crop.get_params( image, (config.resolution, config.resolution) ) image = crop(image, y1, x1, h, w) crop_top_left = (y1, x1) crop_top_lefts.append(crop_top_left) image = train_transforms(image) all_images.append(image) examples["original_sizes"] = original_sizes examples["crop_top_lefts"] = crop_top_lefts examples["pixel_values"] = all_images tokens_one, tokens_two = tokenize_captions(examples) examples["input_ids_one"] = tokens_one examples["input_ids_two"] = tokens_two if config.debug_loss: fnames = [ os.path.basename(image.filename) for image in examples[image_column] if image.filename ] if fnames: examples["filenames"] = fnames return examples with accelerator.main_process_first(): if config.max_train_samples is not None: dataset["test"] = ( dataset["test"] .shuffle(seed=config.seed) .select(range(config.max_train_samples)) ) # Set the training transforms train_dataset = dataset["test"].with_transform( preprocess_train, output_all_columns=True ) def collate_fn(examples): pixel_values = torch.stack([example["pixel_values"] for example in examples]) pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float() original_sizes = [example["original_sizes"] for example in examples] crop_top_lefts = [example["crop_top_lefts"] for example in examples] input_ids_one = torch.stack([example["input_ids_one"] for example in examples]) input_ids_two = torch.stack([example["input_ids_two"] for example in examples]) result = { "pixel_values": pixel_values, "input_ids_one": input_ids_one, "input_ids_two": input_ids_two, "original_sizes": original_sizes, "crop_top_lefts": crop_top_lefts, } filenames = [ example["filenames"] for example in examples if "filenames" in example ] if filenames: result["filenames"] = filenames return result # DataLoaders creation: train_dataloader = torch.utils.data.DataLoader( train_dataset, shuffle=True, collate_fn=collate_fn, batch_size=config.train_batch_size, num_workers=config.dataloader_num_workers, ) # Scheduler and math around the number of training steps. overrode_max_train_steps = False num_update_steps_per_epoch = math.ceil( len(train_dataloader) / config.gradient_accumulation_steps ) if config.max_train_steps is None: config.max_train_steps = config.num_train_epochs * num_update_steps_per_epoch overrode_max_train_steps = True lr_scheduler = get_scheduler( config.lr_scheduler, optimizer=optimizer, num_warmup_steps=config.lr_warmup_steps * config.gradient_accumulation_steps, num_training_steps=config.max_train_steps * config.gradient_accumulation_steps, ) # Prepare everything with our `accelerator`. if config.train_text_encoder: ( unet, text_encoder_one, text_encoder_two, optimizer, train_dataloader, lr_scheduler, ) = accelerator.prepare( unet, text_encoder_one, text_encoder_two, optimizer, train_dataloader, lr_scheduler, ) else: unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( unet, optimizer, train_dataloader, lr_scheduler ) # We need to recalculate our total training steps as the size of the training dataloader may have changed. num_update_steps_per_epoch = math.ceil( len(train_dataloader) / config.gradient_accumulation_steps ) if overrode_max_train_steps: config.max_train_steps = config.num_train_epochs * num_update_steps_per_epoch # Afterwards we recalculate our number of training epochs config.num_train_epochs = math.ceil( config.max_train_steps / num_update_steps_per_epoch ) # We need to initialize the trackers we use, and also store our configuration. # The trackers initializes automatically on the main process. if accelerator.is_main_process: accelerator.init_trackers("text2image-fine-tune", config=vars(config)) # Train! total_batch_size = ( config.train_batch_size * accelerator.num_processes * config.gradient_accumulation_steps ) logger.info("***** Running training *****") logger.info(f" Num examples = {len(train_dataset)}") logger.info(f" Num Epochs = {config.num_train_epochs}") logger.info(f" Instantaneous batch size per device = {config.train_batch_size}") logger.info( f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}" ) logger.info(f" Gradient Accumulation steps = {config.gradient_accumulation_steps}") logger.info(f" Total optimization steps = {config.max_train_steps}") global_step = 0 first_epoch = 0 # Potentially load in the weights and states from a previous save if config.resume_from_checkpoint: if config.resume_from_checkpoint != "latest": path = os.path.basename(config.resume_from_checkpoint) else: # Get the most recent checkpoint dirs = os.listdir(config.output_dir) dirs = [d for d in dirs if d.startswith("checkpoint")] dirs = sorted(dirs, key=lambda x: int(x.split("-")[1])) path = dirs[-1] if len(dirs) > 0 else None if path is None: accelerator.print( f"Checkpoint '{config.resume_from_checkpoint}' does not exist. Starting a new training run." ) config.resume_from_checkpoint = None initial_global_step = 0 else: accelerator.print(f"Resuming from checkpoint {path}") accelerator.load_state(os.path.join(config.output_dir, path)) global_step = int(path.split("-")[1]) initial_global_step = global_step first_epoch = global_step // num_update_steps_per_epoch else: initial_global_step = 0 progress_bar = tqdm( range(0, config.max_train_steps), initial=initial_global_step, desc="Steps", # Only show the progress bar once on each machine. disable=not accelerator.is_local_main_process, ) for epoch in range(first_epoch, config.num_train_epochs): unet.train() if config.train_text_encoder: text_encoder_one.train() text_encoder_two.train() train_loss = 0.0 for step, batch in enumerate(train_dataloader): with accelerator.accumulate(unet): # Convert images to latent space if config.pretrained_vae_model_name_or_path is not None: pixel_values = batch["pixel_values"].to(dtype=weight_dtype) else: pixel_values = batch["pixel_values"] model_input = vae.encode(pixel_values).latent_dist.sample() model_input = model_input * vae.config.scaling_factor if config.pretrained_vae_model_name_or_path is None: model_input = model_input.to(weight_dtype) # Sample noise that we'll add to the latents noise = torch.randn_like(model_input) if config.noise_offset: # https://www.crosslabs.org//blog/diffusion-with-offset-noise noise += config.noise_offset * torch.randn( (model_input.shape[0], model_input.shape[1], 1, 1), device=model_input.device, ) bsz = model_input.shape[0] # Sample a random timestep for each image timesteps = torch.randint( 0, noise_scheduler.config.num_train_timesteps, (bsz,), device=model_input.device, ) timesteps = timesteps.long() # Add noise to the model input according to the noise magnitude at each timestep # (this is the forward diffusion process) noisy_model_input = noise_scheduler.add_noise( model_input, noise, timesteps ) # time ids def compute_time_ids(original_size, crops_coords_top_left): # Adapted from pipeline.StableDiffusionXLPipeline._get_add_time_ids target_size = (config.resolution, config.resolution) add_time_ids = list( original_size + crops_coords_top_left + target_size ) add_time_ids = torch.tensor([add_time_ids]) add_time_ids = add_time_ids.to( accelerator.device, dtype=weight_dtype ) return add_time_ids add_time_ids = torch.cat( [ compute_time_ids(s, c) for s, c in zip( batch["original_sizes"], batch["crop_top_lefts"] ) ] ) # Predict the noise residual unet_added_conditions = {"time_ids": add_time_ids} prompt_embeds, pooled_prompt_embeds = encode_prompt( text_encoders=[text_encoder_one, text_encoder_two], tokenizers=None, prompt=None, text_input_ids_list=[ batch["input_ids_one"], batch["input_ids_two"], ], ) unet_added_conditions.update({"text_embeds": pooled_prompt_embeds}) model_pred = unet( noisy_model_input, timesteps, prompt_embeds, added_cond_kwargs=unet_added_conditions, return_dict=False, )[0] # Get the target for loss depending on the prediction type if config.prediction_type is not None: # set prediction_type of scheduler if defined noise_scheduler.register_to_config( prediction_type=config.prediction_type ) if noise_scheduler.config.prediction_type == "epsilon": target = noise elif noise_scheduler.config.prediction_type == "v_prediction": target = noise_scheduler.get_velocity(model_input, noise, timesteps) else: raise ValueError( f"Unknown prediction type {noise_scheduler.config.prediction_type}" ) if config.snr_gamma is None: loss = F.mse_loss( model_pred.float(), target.float(), reduction="mean" ) else: # Compute loss-weights as per Section 3.4 of https://arxiv.org/abs/2303.09556. # Since we predict the noise instead of x_0, the original formulation is slightly changed. # This is discussed in Section 4.2 of the same paper. snr = compute_snr(noise_scheduler, timesteps) mse_loss_weights = torch.stack( [snr, config.snr_gamma * torch.ones_like(timesteps)], dim=1 ).min(dim=1)[0] if noise_scheduler.config.prediction_type == "epsilon": mse_loss_weights = mse_loss_weights / snr elif noise_scheduler.config.prediction_type == "v_prediction": mse_loss_weights = mse_loss_weights / (snr + 1) loss = F.mse_loss( model_pred.float(), target.float(), reduction="none" ) loss = ( loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights ) loss = loss.mean() if config.debug_loss and "filenames" in batch: for fname in batch["filenames"]: accelerator.log({"loss_for_" + fname: loss}, step=global_step) # Gather the losses across all processes for logging (if we use distributed training). avg_loss = accelerator.gather( loss.repeat(config.train_batch_size) ).mean() train_loss += avg_loss.item() / config.gradient_accumulation_steps # Backpropagate accelerator.backward(loss) if accelerator.sync_gradients: accelerator.clip_grad_norm_( params_to_optimize, config.max_grad_norm ) optimizer.step() lr_scheduler.step() optimizer.zero_grad() # Checks if the accelerator has performed an optimization step behind the scenes if accelerator.sync_gradients: progress_bar.update(1) global_step += 1 accelerator.log({"train_loss": train_loss}, step=global_step) train_loss = 0.0 if accelerator.is_main_process: if global_step % config.checkpointing_steps == 0: # _before_ saving state, check if this save would set us over the `checkpoints_total_limit` if config.checkpoints_total_limit is not None: checkpoints = os.listdir(config.output_dir) checkpoints = [ d for d in checkpoints if d.startswith("checkpoint") ] checkpoints = sorted( checkpoints, key=lambda x: int(x.split("-")[1]) ) # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints if len(checkpoints) >= config.checkpoints_total_limit: num_to_remove = ( len(checkpoints) - config.checkpoints_total_limit + 1 ) removing_checkpoints = checkpoints[0:num_to_remove] logger.info( f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints" ) logger.info( f"removing checkpoints: {', '.join(removing_checkpoints)}" ) for removing_checkpoint in removing_checkpoints: removing_checkpoint = os.path.join( config.output_dir, removing_checkpoint ) shutil.rmtree(removing_checkpoint) save_path = os.path.join( config.output_dir, f"checkpoint-{global_step}" ) accelerator.save_state(save_path) logger.info(f"Saved state to {save_path}") logs = { "step_loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0], } progress_bar.set_postfix(**logs) if global_step >= config.max_train_steps: break if accelerator.is_main_process: if ( config.validation_prompt is not None and epoch % config.validation_epochs == 0 ): logger.info( f"Running validation... \n Generating {config.num_validation_images} images with prompt:" f" {config.validation_prompt}." ) # create pipeline pipeline = StableDiffusionXLPipeline.from_pretrained( config.pretrained_model_name_or_path, vae=vae, text_encoder=unwrap_model(text_encoder_one), text_encoder_2=unwrap_model(text_encoder_two), unet=unwrap_model(unet), revision=config.revision, variant=config.variant, torch_dtype=weight_dtype, ) pipeline = pipeline.to(accelerator.device) pipeline.set_progress_bar_config(disable=True) # run inference generator = ( torch.Generator(device=accelerator.device).manual_seed(config.seed) if config.seed else None ) pipeline_args = {"prompt": config.validation_prompt} if torch.backends.mps.is_available(): autocast_ctx = nullcontext() else: autocast_ctx = torch.autocast(accelerator.device.type) with autocast_ctx: images = [ pipeline(**pipeline_args, generator=generator).images[0] for _ in range(config.num_validation_images) ] for tracker in accelerator.trackers: if tracker.name == "tensorboard": np_images = np.stack([np.asarray(img) for img in images]) tracker.writer.add_images( "validation", np_images, epoch, dataformats="NHWC" ) if tracker.name == "wandb": tracker.log( { "validation": [ wandb.Image( image, caption=f"{i}: {config.validation_prompt}", ) for i, image in enumerate(images) ] } ) del pipeline torch.cuda.empty_cache() # Save the lora layers accelerator.wait_for_everyone() if accelerator.is_main_process: unet = unwrap_model(unet) unet_lora_state_dict = convert_state_dict_to_diffusers( get_peft_model_state_dict(unet) ) if config.train_text_encoder: text_encoder_one = unwrap_model(text_encoder_one) text_encoder_two = unwrap_model(text_encoder_two) text_encoder_lora_layers = convert_state_dict_to_diffusers( get_peft_model_state_dict(text_encoder_one) ) text_encoder_2_lora_layers = convert_state_dict_to_diffusers( get_peft_model_state_dict(text_encoder_two) ) else: text_encoder_lora_layers = None text_encoder_2_lora_layers = None StableDiffusionXLPipeline.save_lora_weights( save_directory=config.output_dir, unet_lora_layers=unet_lora_state_dict, text_encoder_lora_layers=text_encoder_lora_layers, text_encoder_2_lora_layers=text_encoder_2_lora_layers, ) del unet del text_encoder_one del text_encoder_two del text_encoder_lora_layers del text_encoder_2_lora_layers torch.cuda.empty_cache() # Final inference # Make sure vae.dtype is consistent with the unet.dtype if config.mixed_precision == "fp16": vae.to(weight_dtype) # Load previous pipeline pipeline = StableDiffusionXLPipeline.from_pretrained( config.pretrained_model_name_or_path, vae=vae, revision=config.revision, variant=config.variant, torch_dtype=weight_dtype, ) pipeline = pipeline.to(accelerator.device) # load attention processors pipeline.load_lora_weights(config.output_dir) # run inference images = [] if config.validation_prompt and config.num_validation_images > 0: generator = ( torch.Generator(device=accelerator.device).manual_seed(config.seed) if config.seed else None ) images = [ pipeline( config.validation_prompt, num_inference_steps=25, generator=generator, ).images[0] for _ in range(config.num_validation_images) ] for tracker in accelerator.trackers: if tracker.name == "tensorboard": np_images = np.stack([np.asarray(img) for img in images]) tracker.writer.add_images( "test", np_images, epoch, dataformats="NHWC" ) if tracker.name == "wandb": tracker.log( { "test": [ wandb.Image( image, caption=f"{i}: {config.validation_prompt}" ) for i, image in enumerate(images) ] } ) if config.push_to_hub: save_model_card( repo_id, images=images, base_model=config.pretrained_model_name_or_path, dataset_name=config.dataset_name, train_text_encoder=config.train_text_encoder, repo_folder=config.output_dir, vae_path=config.pretrained_vae_model_name_or_path, ) upload_folder( repo_id=repo_id, folder_path=config.output_dir, commit_message="End of training", ignore_patterns=["step_*", "epoch_*"], ) accelerator.end_training() if __name__ == "__main__": main()