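# Training script for GLIGEN-style grounded text-to-image generation with Stable
# Diffusion: only the gated self-attention ("fuser") and position_net modules of
# the UNet are trained, while the VAE, text encoder, and the rest of the UNet
# stay frozen.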
import argparse
import logging
import math
import os
import shutil
from pathlib import Path

import accelerate
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint
import transformers
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import ProjectConfiguration, set_seed
from packaging import version
from tqdm.auto import tqdm
from transformers import CLIPTextModel, CLIPTokenizer

import diffusers
from diffusers import (
    AutoencoderKL,
    DDPMScheduler,
    EulerDiscreteScheduler,
    StableDiffusionGLIGENPipeline,
    UNet2DConditionModel,
)
from diffusers.optimization import get_scheduler
from diffusers.utils import is_wandb_available, make_image_grid
from diffusers.utils.import_utils import is_xformers_available
from diffusers.utils.torch_utils import is_compiled_module


if is_wandb_available():
    pass


logger = get_logger(__name__)


@torch.no_grad()
def log_validation(vae, text_encoder, tokenizer, unet, noise_scheduler, args, accelerator, step, weight_dtype):
    if accelerator.is_main_process:
        print("generate test images...")
    unet = accelerator.unwrap_model(unet)
    # Keep the VAE in fp32 for decoding to avoid precision artifacts in the samples.
    vae.to(accelerator.device, dtype=torch.float32)

    pipeline = StableDiffusionGLIGENPipeline(
        vae,
        text_encoder,
        tokenizer,
        unet,
        EulerDiscreteScheduler.from_config(noise_scheduler.config),
        safety_checker=None,
        feature_extractor=None,
    )
    pipeline = pipeline.to(accelerator.device)
    pipeline.set_progress_bar_config(disable=not accelerator.is_main_process)
    if args.enable_xformers_memory_efficient_attention:
        pipeline.enable_xformers_memory_efficient_attention()

    if args.seed is None:
        generator = None
    else:
        generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)

    prompt = "A realistic image of landscape scene depicting a green car parking on the left of a blue truck, with a red air balloon and a bird in the sky"
    # Grounding boxes in normalized [xmin, ymin, xmax, ymax] coordinates, one per phrase.
    boxes = [
        [0.041015625, 0.548828125, 0.453125, 0.859375],
        [0.525390625, 0.552734375, 0.93359375, 0.865234375],
        [0.12890625, 0.015625, 0.412109375, 0.279296875],
        [0.578125, 0.08203125, 0.857421875, 0.27734375],
    ]
    gligen_phrases = ["a green car", "a blue truck", "a red air balloon", "a bird"]
    images = pipeline(
        prompt=prompt,
        gligen_phrases=gligen_phrases,
        gligen_boxes=boxes,
        gligen_scheduled_sampling_beta=1.0,
        output_type="pil",
        num_inference_steps=50,
        negative_prompt="artifacts, blurry, smooth texture, bad quality, distortions, unrealistic, distorted image, bad proportions, duplicate",
        num_images_per_prompt=4,
        generator=generator,
    ).images
    os.makedirs(os.path.join(args.output_dir, "images"), exist_ok=True)
    make_image_grid(images, 1, 4).save(
        os.path.join(args.output_dir, "images", f"generated-images-{step:06d}-{accelerator.process_index:02d}.png")
    )

    # Restore the VAE to the training dtype.
    vae.to(accelerator.device, dtype=weight_dtype)


def parse_args(input_args=None):
    parser = argparse.ArgumentParser(description="Simple example of a GLIGEN training script.")
    parser.add_argument(
        "--data_path",
        type=str,
        default="coco_train2017.pth",
        help="Path to training dataset.",
    )
    parser.add_argument(
        "--image_path",
        type=str,
        default="coco_train2017.pth",
        help="Path to training images.",
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="controlnet-model",
        help="The output directory where the model predictions and checkpoints will be written.",
    )
    parser.add_argument("--seed", type=int, default=0, help="A seed for reproducible training.")
    parser.add_argument(
        "--resolution",
        type=int,
        default=512,
        help=(
            "The resolution for input images. All the images in the train/validation dataset will be resized to this"
            " resolution."
        ),
    )
    parser.add_argument(
        "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader."
    )
    parser.add_argument("--num_train_epochs", type=int, default=1)
    parser.add_argument(
        "--max_train_steps",
        type=int,
        default=None,
        help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
    )
    parser.add_argument(
        "--checkpointing_steps",
        type=int,
        default=500,
        help=(
            "Save a checkpoint of the training state every X updates. Checkpoints can be used for resuming training via `--resume_from_checkpoint`. "
            "In the case that the checkpoint is better than the final trained model, the checkpoint can also be used for inference. "
            "Using a checkpoint for inference requires separate loading of the original pipeline and the individual checkpointed model components. "
            "See https://huggingface.co/docs/diffusers/main/en/training/dreambooth#performing-inference-using-a-saved-checkpoint for step-by-step"
            " instructions."
        ),
    )
    parser.add_argument(
        "--checkpoints_total_limit",
        type=int,
        default=None,
        help="Max number of checkpoints to store.",
    )
    parser.add_argument(
        "--resume_from_checkpoint",
        type=str,
        default=None,
        help=(
            "Whether training should be resumed from a previous checkpoint. Use a path saved by"
            ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
        ),
    )
    parser.add_argument(
        "--gradient_accumulation_steps",
        type=int,
        default=1,
        help="Number of update steps to accumulate before performing a backward/update pass.",
    )
    parser.add_argument(
        "--gradient_checkpointing",
        action="store_true",
        help="Whether or not to use gradient checkpointing to save memory at the expense of a slower backward pass.",
    )
    parser.add_argument(
        "--learning_rate",
        type=float,
        default=5e-6,
        help="Initial learning rate (after the potential warmup period) to use.",
    )
    parser.add_argument(
        "--scale_lr",
        action="store_true",
        default=False,
        help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
    )
    parser.add_argument(
        "--lr_scheduler",
        type=str,
        default="constant",
        help=(
            'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
            ' "constant", "constant_with_warmup"]'
        ),
    )
    parser.add_argument(
        "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
    )
    parser.add_argument(
        "--lr_num_cycles",
        type=int,
        default=1,
        help="Number of hard resets of the lr in cosine_with_restarts scheduler.",
    )
    parser.add_argument("--lr_power", type=float, default=1.0, help="Power factor of the polynomial scheduler.")
    parser.add_argument(
        "--dataloader_num_workers",
        type=int,
        default=0,
        help="Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process.",
    )
    parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
    parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
    parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
    parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer.")
    parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
    parser.add_argument(
        "--logging_dir",
        type=str,
        default="logs",
        help=(
            "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
            " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
        ),
    )
    parser.add_argument(
        "--allow_tf32",
        action="store_true",
        help=(
            "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
            " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
        ),
    )
    parser.add_argument(
        "--report_to",
        type=str,
        default="tensorboard",
        help=(
            'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
            ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
        ),
    )
    parser.add_argument(
        "--mixed_precision",
        type=str,
        default=None,
        choices=["no", "fp16", "bf16"],
        help=(
            "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
            " 1.10 and an Nvidia Ampere GPU. Defaults to the value of the accelerate config of the current system or"
            " the flag passed with the `accelerate.launch` command. Use this argument to override the accelerate"
            " config."
        ),
    )
    parser.add_argument(
        "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
    )
    parser.add_argument(
        "--set_grads_to_none",
        action="store_true",
        help=(
            "Save more memory by setting grads to None instead of zero. Be aware that this changes certain"
            " behaviors, so disable this argument if it causes any problems. More info:"
            " https://pytorch.org/docs/stable/generated/torch.optim.Optimizer.zero_grad.html"
        ),
    )
    parser.add_argument(
        "--tracker_project_name",
        type=str,
        default="train_controlnet",
        help=(
            "The `project_name` argument passed to Accelerator.init_trackers. For more information see"
            " https://huggingface.co/docs/accelerate/v0.17.0/en/package_reference/accelerator#accelerate.Accelerator"
        ),
    )
    args = parser.parse_args(input_args)
    return args
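

# Example launch (a sketch; the script name and paths are placeholders for your setup):
#   accelerate launch train_gligen.py \
#       --data_path coco_train2017.pth \
#       --image_path /path/to/coco/train2017 \
#       --output_dir gligen-model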


def main(args):
    logging_dir = Path(args.output_dir, args.logging_dir)

    accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)

    accelerator = Accelerator(
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        mixed_precision=args.mixed_precision,
        log_with=args.report_to,
        project_config=accelerator_project_config,
    )

    # Disable AMP for MPS.
    if torch.backends.mps.is_available():
        accelerator.native_amp = False

    # Make one log on every process with the configuration for debugging.
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )
    logger.info(accelerator.state, main_process_only=False)
    if accelerator.is_local_main_process:
        transformers.utils.logging.set_verbosity_warning()
        diffusers.utils.logging.set_verbosity_info()
    else:
        transformers.utils.logging.set_verbosity_error()
        diffusers.utils.logging.set_verbosity_error()

    # If passed along, set the training seed now.
    if args.seed is not None:
        set_seed(args.seed)

    # Create the output directory on the main process only.
    if accelerator.is_main_process:
        if args.output_dir is not None:
            os.makedirs(args.output_dir, exist_ok=True)

    # Load the tokenizer, noise scheduler, and the frozen models from the
    # pretrained GLIGEN checkpoint.
    pretrained_model_name_or_path = "masterful/gligen-1-4-generation-text-box"
    tokenizer = CLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder="tokenizer")
    noise_scheduler = DDPMScheduler.from_pretrained(pretrained_model_name_or_path, subfolder="scheduler")
    text_encoder = CLIPTextModel.from_pretrained(pretrained_model_name_or_path, subfolder="text_encoder")

    vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder="vae")
    unet = UNet2DConditionModel.from_pretrained(pretrained_model_name_or_path, subfolder="unet")

    def unwrap_model(model):
        model = accelerator.unwrap_model(model)
        model = model._orig_mod if is_compiled_module(model) else model
        return model
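
    # With accelerate >= 0.16.0 we register custom hooks so that `accelerator.save_state`
    # and `accelerator.load_state` (de)serialize the UNet in the diffusers layout
    # (a `unet/` subfolder) rather than as an opaque state dict.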
    if version.parse(accelerate.__version__) >= version.parse("0.16.0"):

        def save_model_hook(models, weights, output_dir):
            if accelerator.is_main_process:
                i = len(weights) - 1

                while len(weights) > 0:
                    weights.pop()
                    model = models[i]

                    sub_dir = "unet"
                    model.save_pretrained(os.path.join(output_dir, sub_dir))

                    i -= 1

        def load_model_hook(models, input_dir):
            while len(models) > 0:
                # Pop models so that they are not loaded again by accelerate.
                model = models.pop()

                load_model = UNet2DConditionModel.from_pretrained(input_dir, subfolder="unet")
                model.register_to_config(**load_model.config)

                model.load_state_dict(load_model.state_dict())
                del load_model

        accelerator.register_save_state_pre_hook(save_model_hook)
        accelerator.register_load_state_pre_hook(load_model_hook)

    # Freeze everything; the GLIGEN-specific modules are unfrozen further below.
    vae.requires_grad_(False)
    unet.requires_grad_(False)
    text_encoder.requires_grad_(False)

    if args.enable_xformers_memory_efficient_attention:
        if is_xformers_available():
            import xformers

            xformers_version = version.parse(xformers.__version__)
            if xformers_version == version.parse("0.0.16"):
                logger.warning(
                    "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
                )
            unet.enable_xformers_memory_efficient_attention()
        else:
            raise ValueError("xformers is not available. Make sure it is installed correctly")

    low_precision_error_string = (
        " Please make sure to always have all model weights in full float32 precision when starting training - even if"
        " doing mixed precision training, a copy of the weights should still be float32."
    )

    if unwrap_model(unet).dtype != torch.float32:
        raise ValueError(f"UNet loaded as datatype {unwrap_model(unet).dtype}. {low_precision_error_string}")

    # Enable TF32 for faster training on Ampere GPUs,
    # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
    if args.allow_tf32:
        torch.backends.cuda.matmul.allow_tf32 = True

    if args.scale_lr:
        args.learning_rate = (
            args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
        )

    optimizer_class = torch.optim.AdamW

    # Re-initialize the gated self-attention ("fuser") and position_net layers;
    # these are the only parameters that will be trained.
    for n, m in unet.named_modules():
        if ("fuser" in n) or ("position_net" in n):
            if isinstance(m, (nn.Linear, nn.LayerNorm)):
                m.reset_parameters()

    params_to_optimize = []
    for n, p in unet.named_parameters():
        if ("fuser" in n) or ("position_net" in n):
            p.requires_grad = True
            params_to_optimize.append(p)

    optimizer = optimizer_class(
        params_to_optimize,
        lr=args.learning_rate,
        betas=(args.adam_beta1, args.adam_beta2),
        weight_decay=args.adam_weight_decay,
        eps=args.adam_epsilon,
    )
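
    # `COCODataset` (a local module) is expected to yield per-sample: `pixel_values`,
    # a tokenized `caption`, grounding `boxes` padded to `max_boxes_per_data`, `masks`
    # marking which boxes are real, and `text_embeddings_before_projection` (CLIP
    # embeddings of the grounding phrases); these are the keys consumed by the
    # training loop below.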
    from dataset import COCODataset

    train_dataset = COCODataset(
        data_path=args.data_path,
        image_path=args.image_path,
        tokenizer=tokenizer,
        image_size=args.resolution,
        max_boxes_per_data=30,
    )

    print("num samples: ", len(train_dataset))

    train_dataloader = torch.utils.data.DataLoader(
        train_dataset,
        shuffle=True,
        batch_size=args.train_batch_size,
        num_workers=args.dataloader_num_workers,
    )

    # Scheduler and math around the number of training steps.
    overrode_max_train_steps = False
    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
    if args.max_train_steps is None:
        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
        overrode_max_train_steps = True

    lr_scheduler = get_scheduler(
        args.lr_scheduler,
        optimizer=optimizer,
        num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
        num_training_steps=args.max_train_steps * accelerator.num_processes,
        num_cycles=args.lr_num_cycles,
        power=args.lr_power,
    )

    # Prepare everything with our `accelerator`.
    unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
        unet, optimizer, train_dataloader, lr_scheduler
    )

    # For mixed precision training we cast the weights of the frozen models to
    # half-precision, as they are only used for inference.
    weight_dtype = torch.float32
    if accelerator.mixed_precision == "fp16":
        weight_dtype = torch.float16
    elif accelerator.mixed_precision == "bf16":
        weight_dtype = torch.bfloat16

    # Move the frozen models to device in weight_dtype; keep the trainable UNet in float32.
    vae.to(accelerator.device, dtype=weight_dtype)
    unet.to(accelerator.device, dtype=torch.float32)
    text_encoder.to(accelerator.device, dtype=weight_dtype)

    # We need to recalculate our total training steps as the size of the training dataloader may have changed.
    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
    if overrode_max_train_steps:
        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
    # Afterwards we recalculate our number of training epochs.
    args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)

    # Initialize the trackers we use and store our configuration.
    if accelerator.is_main_process:
        tracker_config = dict(vars(args))
        accelerator.init_trackers(args.tracker_project_name, config=tracker_config)

    global_step = 0
    first_epoch = 0

    # Potentially load in the weights and states from a previous save.
    if args.resume_from_checkpoint:
        if args.resume_from_checkpoint != "latest":
            path = os.path.basename(args.resume_from_checkpoint)
        else:
            # Get the most recent checkpoint.
            dirs = os.listdir(args.output_dir)
            dirs = [d for d in dirs if d.startswith("checkpoint")]
            dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
            path = dirs[-1] if len(dirs) > 0 else None

        if path is None:
            accelerator.print(
                f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
            )
            args.resume_from_checkpoint = None
            initial_global_step = 0
        else:
            accelerator.print(f"Resuming from checkpoint {path}")
            accelerator.load_state(os.path.join(args.output_dir, path))
            global_step = int(path.split("-")[1])

            initial_global_step = global_step
            first_epoch = global_step // num_update_steps_per_epoch
    else:
        initial_global_step = 0

    progress_bar = tqdm(
        range(0, args.max_train_steps),
        initial=initial_global_step,
        desc="Steps",
        # Only show the progress bar once on each machine.
        disable=not accelerator.is_local_main_process,
    )

    # Generate a round of samples before training starts, as a baseline.
    log_validation(
        vae,
        text_encoder,
        tokenizer,
        unet,
        noise_scheduler,
        args,
        accelerator,
        global_step,
        weight_dtype,
    )
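
    # Training loop: a standard denoising objective; the GLIGEN grounding inputs
    # (boxes, phrase embeddings, masks) are injected through cross_attention_kwargs.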
    for epoch in range(first_epoch, args.num_train_epochs):
        for step, batch in enumerate(train_dataloader):
            with accelerator.accumulate(unet):
                # Convert images to latent space.
                latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample()
                latents = latents * vae.config.scaling_factor

                # Sample noise that we'll add to the latents.
                noise = torch.randn_like(latents)
                bsz = latents.shape[0]
                # Sample a random timestep for each image.
                timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
                timesteps = timesteps.long()

                # Add noise to the latents according to the noise magnitude at each timestep
                # (this is the forward diffusion process).
                noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

                with torch.no_grad():
                    # Get the text embedding for conditioning.
                    encoder_hidden_states = text_encoder(
                        batch["caption"]["input_ids"].squeeze(1),
                        return_dict=False,
                    )[0]

                # Pass the grounding inputs to the GLIGEN gated self-attention layers.
                cross_attention_kwargs = {}
                cross_attention_kwargs["gligen"] = {
                    "boxes": batch["boxes"],
                    "positive_embeddings": batch["text_embeddings_before_projection"],
                    "masks": batch["masks"],
                }

                # Predict the noise residual.
                model_pred = unet(
                    noisy_latents,
                    timesteps,
                    encoder_hidden_states=encoder_hidden_states,
                    cross_attention_kwargs=cross_attention_kwargs,
                    return_dict=False,
                )[0]

                # Get the target for loss depending on the prediction type.
                if noise_scheduler.config.prediction_type == "epsilon":
                    target = noise
                elif noise_scheduler.config.prediction_type == "v_prediction":
                    target = noise_scheduler.get_velocity(latents, noise, timesteps)
                else:
                    raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
                loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")

                accelerator.backward(loss)
                if accelerator.sync_gradients:
                    accelerator.clip_grad_norm_(params_to_optimize, args.max_grad_norm)
                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad(set_to_none=args.set_grads_to_none)

            # Checks if the accelerator has performed an optimization step behind the scenes.
            if accelerator.sync_gradients:
                progress_bar.update(1)
                global_step += 1

                if global_step % args.checkpointing_steps == 0:
                    if accelerator.is_main_process:
                        # Before saving a new checkpoint, honor `checkpoints_total_limit`
                        # by deleting the oldest checkpoints.
                        if args.checkpoints_total_limit is not None:
                            checkpoints = os.listdir(args.output_dir)
                            checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
                            checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))

                            # Before we save the new checkpoint, we need to have at most
                            # `checkpoints_total_limit - 1` checkpoints.
                            if len(checkpoints) >= args.checkpoints_total_limit:
                                num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1
                                removing_checkpoints = checkpoints[0:num_to_remove]

                                logger.info(
                                    f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
                                )
                                logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}")

                                for removing_checkpoint in removing_checkpoints:
                                    removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)
                                    shutil.rmtree(removing_checkpoint)

                        save_path = os.path.join(args.output_dir, f"checkpoint-{global_step:06d}")
                        accelerator.save_state(save_path)
                        logger.info(f"Saved state to {save_path}")

                    log_validation(
                        vae,
                        text_encoder,
                        tokenizer,
                        unet,
                        noise_scheduler,
                        args,
                        accelerator,
                        global_step,
                        weight_dtype,
                    )

            logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
            progress_bar.set_postfix(**logs)
            accelerator.log(logs, step=global_step)

            if global_step >= args.max_train_steps:
                break

    # Training is done; save the final UNet (with the trained GLIGEN modules).
    accelerator.wait_for_everyone()
    if accelerator.is_main_process:
        unet = unwrap_model(unet)
        unet.save_pretrained(args.output_dir)
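
    # The trained UNet can later be plugged back into the pipeline for inference,
    # e.g. (a sketch; diffusers lets you override components in `from_pretrained`):
    #   unet = UNet2DConditionModel.from_pretrained(args.output_dir)
    #   pipe = StableDiffusionGLIGENPipeline.from_pretrained(
    #       "masterful/gligen-1-4-generation-text-box", unet=unet
    #   )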

    accelerator.end_training()


if __name__ == "__main__":
    args = parse_args()
    main(args)