UltraEdit-SD3 / UltraEdit /traning /train_sd3_pix2pix.py
BleachNick's picture
upload required packages
87d40d2
import torch.nn.init as init
import argparse
from cgitb import text
import copy
import gc
import itertools
import logging
import math
import os
import random
import shutil
from tkinter import NO
import warnings
from contextlib import nullcontext
from pathlib import Path
import PIL.Image
import PIL.ImageOps
import numpy as np
from sympy import N
import torch
import torch.utils.checkpoint
import transformers
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import DistributedDataParallelKwargs, ProjectConfiguration, set_seed
from huggingface_hub import create_repo, upload_folder
from huggingface_hub.utils import insecure_hashlib
from PIL import Image
from PIL.ImageOps import exif_transpose
from torch.utils.data import Dataset
from torchvision import transforms
from torchvision.transforms.functional import crop
from tqdm.auto import tqdm
from transformers import CLIPTextModelWithProjection, CLIPTokenizer, PretrainedConfig, T5EncoderModel, T5TokenizerFast
# from transformer_sd3 import SD3Transformer2DModel
import diffusers
from diffusers import (
AutoencoderKL,
FlowMatchEulerDiscreteScheduler,
StableDiffusion3Pipeline,
SD3Transformer2DModel,
StableDiffusion3InstructPix2PixPipeline
)
from diffusers.optimization import get_scheduler
from diffusers.utils import (
check_min_version,
is_wandb_available,
)
from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
from diffusers.utils.torch_utils import is_compiled_module
import accelerate
import datasets
import PIL
import requests
import torch.nn as nn
import torch.nn.functional as F
from os.path import join
from datasets import load_dataset
from packaging import version
def load_text_encoders(class_one, class_two, class_three):
    """Instantiate the three text encoders of the pretrained checkpoint.

    Each class is loaded with ``from_pretrained`` from its own subfolder
    (``text_encoder``, ``text_encoder_2``, ``text_encoder_3``) of
    ``args.pretrained_model_name_or_path``. The model path, revision and
    variant are read from the module-level ``args`` namespace.

    Returns:
        A ``(text_encoder_one, text_encoder_two, text_encoder_three)`` tuple,
        in the same order as the classes were given.
    """
    encoder_specs = (
        (class_one, "text_encoder"),
        (class_two, "text_encoder_2"),
        (class_three, "text_encoder_3"),
    )
    # The generator is consumed in order by the unpacking, so the encoders are
    # loaded one, two, three.
    text_encoder_one, text_encoder_two, text_encoder_three = (
        encoder_cls.from_pretrained(
            args.pretrained_model_name_or_path,
            subfolder=subfolder,
            revision=args.revision,
            variant=args.variant,
        )
        for encoder_cls, subfolder in encoder_specs
    )
    return text_encoder_one, text_encoder_two, text_encoder_three
def tokenize_prompt(tokenizer, prompt, max_sequence_length=77):
    """Tokenize ``prompt`` and return only the token ids.

    The tokenizer is called with padding to ``max_length``, truncation
    enabled and ``return_tensors="pt"``; the ``input_ids`` attribute of its
    result is returned.

    Args:
        tokenizer: Callable tokenizer whose result exposes ``input_ids``.
        prompt: Prompt (or prompts) forwarded unchanged to the tokenizer.
        max_sequence_length: Value passed as the tokenizer's ``max_length``.
    """
    tokenizer_kwargs = {
        "padding": "max_length",
        "max_length": max_sequence_length,
        "truncation": True,
        "return_tensors": "pt",
    }
    return tokenizer(prompt, **tokenizer_kwargs).input_ids
def _encode_prompt_with_t5(
text_encoder,
tokenizer,
max_sequence_length,
text_encoder_dtype,
prompt=None,
num_images_per_prompt=1,
device=None,
text_input_ids=None
):
prompt = [prompt] if isinstance(prompt, str) else prompt
batch_size = len(prompt)
if text_input_ids is None:
text_inputs = tokenizer(
prompt,
padding="max_length",
max_length=max_sequence_length,
truncation=True,
add_special_tokens=True,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids
prompt_embeds = text_encoder(text_input_ids.to(device))[0]
prompt_embeds = prompt_embeds.to(dtype=text_encoder_dtype, device=device)
_, seq_len, _ = prompt_embeds.shape
# duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
return prompt_embeds
def _encode_prompt_with_clip(
text_encoder,
tokenizer,
prompt: str,
text_encoder_dtype,
device=None,
num_images_per_prompt: int = 1,
text_input_ids=None
):
prompt = [prompt] if isinstance(prompt, str) else prompt
batch_size = len(prompt)
if text_input_ids is None:
text_inputs = tokenizer(
prompt,
padding="max_length",
max_length=77,
truncation=True,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids
prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)
pooled_prompt_embeds = prompt_embeds[0]
prompt_embeds = prompt_embeds.hidden_states[-2]
prompt_embeds = prompt_embeds.to(dtype=text_encoder_dtype, device=device)
_, seq_len, _ = prompt_embeds.shape
# duplicate text embeddings for each generation per prompt, using mps friendly method
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
return prompt_embeds, pooled_prompt_embeds
def encode_prompt(
    text_encoders,
    tokenizers,
    prompt: str,
    max_sequence_length=None,
    text_encoders_dtypes=(torch.float32, torch.float32, torch.float32),
    device=None,
    num_images_per_prompt: int = 1,
    text_input_ids_list=None
):
    """Build the combined prompt embeddings from two CLIP encoders and one T5 encoder.

    The first two entries of ``text_encoders`` / ``tokenizers`` /
    ``text_encoders_dtypes`` are treated as the CLIP encoders and the last
    entry as the T5 encoder. The two CLIP per-token embeddings are
    concatenated on the last dimension, right-padded with zeros to the T5
    feature width, and then concatenated with the T5 embeddings along the
    sequence dimension.

    Args:
        text_encoders: Sequence of three text encoders (CLIP, CLIP, T5).
        tokenizers: Sequence of three matching tokenizers.
        prompt: A string or list of strings.
        max_sequence_length: ``max_length`` used when tokenizing for T5.
        text_encoders_dtypes: Output dtype for each encoder, in the same order
            as ``text_encoders``.
        device: Target device; defaults to each encoder's own ``device``.
        num_images_per_prompt: How many times each embedding is duplicated.
        text_input_ids_list: Optional pre-tokenized ids, one entry per encoder.

    Returns:
        ``(prompt_embeds, pooled_prompt_embeds)`` where the pooled embeddings
        are the two CLIP pooled outputs concatenated on the last dimension.
    """
    prompt = [prompt] if isinstance(prompt, str) else prompt

    clip_tokenizers = tokenizers[:2]
    clip_text_encoders = text_encoders[:2]
    clip_text_encoders_dtypes = text_encoders_dtypes[:2]
    if text_input_ids_list is not None:
        clip_text_input_ids_list = text_input_ids_list[:2]
        t5_text_input_ids = text_input_ids_list[-1]
    else:
        clip_text_input_ids_list = [None, None]
        t5_text_input_ids = None

    clip_prompt_embeds_list = []
    clip_pooled_prompt_embeds_list = []
    zipped_text_encoders = zip(
        clip_tokenizers, clip_text_encoders, clip_text_encoders_dtypes, clip_text_input_ids_list
    )
    for tokenizer, text_encoder, clip_text_encoder_dtype, text_input_ids in zipped_text_encoders:
        prompt_embeds, pooled_prompt_embeds = _encode_prompt_with_clip(
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            prompt=prompt,
            text_encoder_dtype=clip_text_encoder_dtype,
            device=device if device is not None else text_encoder.device,
            num_images_per_prompt=num_images_per_prompt,
            text_input_ids=text_input_ids,
        )
        clip_prompt_embeds_list.append(prompt_embeds)
        clip_pooled_prompt_embeds_list.append(pooled_prompt_embeds)

    clip_prompt_embeds = torch.cat(clip_prompt_embeds_list, dim=-1)
    pooled_prompt_embeds = torch.cat(clip_pooled_prompt_embeds_list, dim=-1)

    t5_prompt_embed = _encode_prompt_with_t5(
        text_encoders[-1],
        tokenizers[-1],
        max_sequence_length,
        # Use the T5 encoder's own dtype entry, not the second CLIP encoder's.
        text_encoders_dtypes[-1],
        prompt=prompt,
        num_images_per_prompt=num_images_per_prompt,
        device=device if device is not None else text_encoders[-1].device,
        text_input_ids=t5_text_input_ids
    )

    # Zero-pad the CLIP features on the last dimension up to the T5 feature
    # width so the two can be stacked along the sequence dimension.
    clip_prompt_embeds = torch.nn.functional.pad(
        clip_prompt_embeds, (0, t5_prompt_embed.shape[-1] - clip_prompt_embeds.shape[-1])
    )
    prompt_embeds = torch.cat([clip_prompt_embeds, t5_prompt_embed], dim=-2)

    return prompt_embeds, pooled_prompt_embeds
# Module-level logger obtained from accelerate's `get_logger`, at INFO level.
logger = get_logger(__name__, log_level="INFO")
# Per-dataset tuple of three column names; by their names these are the source
# image, edited image and edit prompt columns (consumer not shown in this
# section — confirm the expected order against it).
DATASET_NAME_MAPPING = {
"BleachNick/UltraEdit_500k": ("source_image", "edited_image", "edit_prompt"),
}
# Column names for a wandb table; the code that builds the table is not visible
# in this section.
WANDB_TABLE_COL_NAMES = ["source_image", "edited_image", "edit_prompt"]
def _str2bool(value):
    """Parse a command-line boolean such as ``True``/``false``/``1``/``no``.

    ``argparse`` with ``type=bool`` treats every non-empty string (including
    ``"False"``) as ``True``; this converter interprets the text instead.
    The empty string maps to ``False``, matching ``bool("")``.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in {"1", "true", "t", "yes", "y", "on"}:
        return True
    if lowered in {"", "0", "false", "f", "no", "n", "off"}:
        return False
    raise argparse.ArgumentTypeError(f"Expected a boolean value, got {value!r}.")


def parse_args(input_args=None):
    """Parse the command-line arguments of the training script.

    Args:
        input_args: Optional list of argument strings. When ``None`` (the
            default) the arguments are read from ``sys.argv``.

    Returns:
        The populated ``argparse.Namespace``. ``local_rank`` is overridden by
        the ``LOCAL_RANK`` environment variable when that variable is set and
        differs from the parsed value.

    Raises:
        ValueError: If neither ``--dataset_name`` nor ``--train_data_jsonl``
            is given.
    """
    parser = argparse.ArgumentParser(description="Simple example of a training script for InstructPix2Pix.")
    parser.add_argument(
        "--pretrained_model_name_or_path",
        type=str,
        default=None,
        required=True,
        help="Path to pretrained model or model identifier from huggingface.co/models.",
    )
    parser.add_argument(
        "--revision",
        type=str,
        default=None,
        required=False,
        help="Revision of pretrained model identifier from huggingface.co/models.",
    )
    parser.add_argument(
        "--ori_model_name_or_path",
        type=str,
        default=None,
        help="Path to ori_model_name_or_path.",
    )
    parser.add_argument(
        "--weighting_scheme", type=str, default="logit_normal", choices=["sigma_sqrt", "logit_normal", "mode", "cosmap"]
    )
    parser.add_argument(
        "--logit_mean", type=float, default=0.0, help="mean to use when using the `'logit_normal'` weighting scheme."
    )
    parser.add_argument(
        "--logit_std", type=float, default=1.0, help="std to use when using the `'logit_normal'` weighting scheme."
    )
    parser.add_argument(
        "--mode_scale",
        type=float,
        default=1.29,
        help="Scale of mode weighting scheme. Only effective when using the `'mode'` as the `weighting_scheme`.",
    )
    parser.add_argument(
        "--optimizer",
        type=str,
        default="AdamW",
        help=('The optimizer type to use. Choose between ["AdamW", "prodigy"]'),
    )
    parser.add_argument(
        "--use_8bit_adam",
        action="store_true",
        help="Whether or not to use 8-bit Adam from bitsandbytes. Ignored if optimizer is not set to AdamW",
    )
    parser.add_argument(
        "--adam_weight_decay_text_encoder", type=float, default=1e-03, help="Weight decay to use for text_encoder"
    )
    parser.add_argument(
        "--prodigy_beta3",
        type=float,
        default=None,
        help="coefficients for computing the Prodidy stepsize using running averages. If set to None, "
        "uses the value of square root of beta2. Ignored if optimizer is adamW",
    )
    # The three prodigy switches take an explicit value; _str2bool makes
    # `--flag False` actually yield False (type=bool would yield True).
    parser.add_argument(
        "--prodigy_use_bias_correction",
        type=_str2bool,
        default=True,
        help="Turn on Adam's bias correction. True by default. Ignored if optimizer is adamW",
    )
    parser.add_argument(
        "--prodigy_decouple", type=_str2bool, default=True, help="Use AdamW style decoupled weight decay"
    )
    parser.add_argument(
        "--prodigy_safeguard_warmup",
        type=_str2bool,
        default=True,
        help="Remove lr from the denominator of D estimate to avoid issues during warm-up stage. True by default. "
        "Ignored if optimizer is adamW",
    )
    parser.add_argument(
        "--variant",
        type=str,
        default=None,
        help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16",
    )
    parser.add_argument(
        "--dataset_name",
        type=str,
        default=None,
        help=(
            "The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private,"
            " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem,"
            " or to a folder containing files that 🤗 Datasets can understand."
        ),
    )
    parser.add_argument(
        "--dataset_config_name",
        type=str,
        default=None,
        help="The config of the Dataset, leave as None if there's only one config.",
    )
    parser.add_argument(
        "--train_data_jsonl",
        type=str,
        default=None,
        help=(
            "A folder containing the training data. Folder contents must follow the structure described in"
            " https://huggingface.co/docs/datasets/image_dataset#imagefolder. In particular, a `metadata.jsonl` file"
            " must exist to provide the captions for the images. Ignored if `dataset_name` is specified."
        ),
    )
    parser.add_argument(
        "--original_image_column",
        type=str,
        default="source_image",
        help="The column of the dataset containing the original image on which edits where made.",
    )
    parser.add_argument(
        "--config_file",
        type=str,
        default=None,
    )
    parser.add_argument(
        "--edited_image_column",
        type=str,
        default="edited_image",
        help="The column of the dataset containing the edited image.",
    )
    parser.add_argument(
        "--edit_prompt_column",
        type=str,
        default="edit_prompt",
        help="The column of the dataset containing the edit instruction.",
    )
    parser.add_argument(
        "--val_image_url",
        type=str,
        default=None,
        help="URL to the original image that you would like to edit (used during inference for debugging purposes).",
    )
    parser.add_argument(
        '--val_mask_url',
        type=str,
        default=None,
        help="URL to the mask image that you would like to edit (used during inference for debugging purposes).",
    )
    parser.add_argument(
        "--validation_prompt", type=str, default=None, help="A prompt that is sampled during training for inference."
    )
    parser.add_argument(
        "--num_validation_images",
        type=int,
        default=4,
        help="Number of images that should be generated during validation with `validation_prompt`.",
    )
    parser.add_argument(
        "--validation_epochs",
        type=int,
        default=1,
        help=(
            "Run fine-tuning validation every X epochs. The validation process consists of running the prompt"
            " `args.validation_prompt` multiple times: `args.num_validation_images`."
        ),
    )
    parser.add_argument(
        "--validation_step",
        type=int,
        default=5000,
        help=(
            "Run fine-tuning validation every X steps. The validation process consists of running the prompt"
            " `args.validation_prompt` multiple times: `args.num_validation_images`."
        ),
    )
    parser.add_argument(
        "--top_training_data_sample",
        type=int,
        default=None,
        help="Number of top samples to use for training, ranked by clip-sim-dit. If None, use the full dataset.",
    )
    parser.add_argument(
        "--max_train_samples",
        type=int,
        default=None,
        help=(
            "For debugging purposes or quicker training, truncate the number of training examples to this "
            "value if set."
        ),
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="sd3_edit",
        help="The output directory where the model predictions and checkpoints will be written.",
    )
    parser.add_argument(
        "--cache_dir",
        type=str,
        default=None,
        help="The directory where the downloaded models and datasets will be stored.",
    )
    parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
    parser.add_argument(
        "--resolution",
        type=int,
        default=256,
        help=(
            "The resolution for input images, all the images in the train/validation dataset will be resized to this"
            " resolution"
        ),
    )
    parser.add_argument(
        "--eval_resolution",
        type=int,
        default=512,
        help=(
            "The resolution for input images, all the images in the train/validation dataset will be resized to this"
            " resolution"
        ),
    )
    parser.add_argument(
        "--max_sequence_length",
        type=int,
        default=77,
        help="Maximum sequence length to use with with the T5 text encoder",
    )
    parser.add_argument(
        "--center_crop",
        default=False,
        action="store_true",
        help=(
            "Whether to center crop the input images to the resolution. If not set, the images will be randomly"
            " cropped. The images will be resized to the resolution first before cropping."
        ),
    )
    parser.add_argument(
        "--random_flip",
        action="store_true",
        help="whether to randomly flip images horizontally",
    )
    parser.add_argument(
        "--train_batch_size", type=int, default=16, help="Batch size (per device) for the training dataloader."
    )
    parser.add_argument("--num_train_epochs", type=int, default=100)
    parser.add_argument(
        "--max_train_steps",
        type=int,
        default=None,
        help="Total number of training steps to perform.  If provided, overrides num_train_epochs.",
    )
    parser.add_argument(
        "--gradient_accumulation_steps",
        type=int,
        default=1,
        help="Number of updates steps to accumulate before performing a backward/update pass.",
    )
    parser.add_argument(
        "--gradient_checkpointing",
        action="store_true",
        help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
    )
    parser.add_argument(
        "--learning_rate",
        type=float,
        default=1e-4,
        help="Initial learning rate (after the potential warmup period) to use.",
    )
    parser.add_argument(
        "--scale_lr",
        action="store_true",
        default=False,
        help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
    )
    parser.add_argument(
        "--lr_scheduler",
        type=str,
        default="cosine",
        help=(
            'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
            ' "constant", "constant_with_warmup"]'
        ),
    )
    parser.add_argument(
        "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
    )
    parser.add_argument(
        "--conditioning_dropout_prob",
        type=float,
        default=None,
        help="Conditioning dropout probability. Drops out the conditionings (image and edit prompt) used in training InstructPix2Pix. See section 3.2.1 in the paper: https://arxiv.org/abs/2211.09800.",
    )
    parser.add_argument(
        "--allow_tf32",
        action="store_true",
        help=(
            "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
            " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
        ),
    )
    parser.add_argument(
        "--text_encoder_lr",
        type=float,
        default=5e-6,
        help="Text encoder learning rate to use.",
    )
    parser.add_argument(
        "--dataloader_num_workers",
        type=int,
        default=0,
        help=(
            "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
        ),
    )
    parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
    parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
    parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
    parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
    parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
    parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
    parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
    parser.add_argument(
        "--hub_model_id",
        type=str,
        default=None,
        help="The name of the repository to keep in sync with the local `output_dir`.",
    )
    parser.add_argument(
        "--logging_dir",
        type=str,
        default="logs",
        help=(
            "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
            " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
        ),
    )
    parser.add_argument(
        "--mixed_precision",
        type=str,
        default=None,
        choices=["no", "fp16", "bf16"],
        help=(
            "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
            " 1.10.and an Nvidia Ampere GPU.  Default to the value of accelerate config of the current system or the"
            " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
        ),
    )
    parser.add_argument(
        "--report_to",
        type=str,
        default="tensorboard",
        help=(
            'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
            ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
        ),
    )
    parser.add_argument(
        "--lr_num_cycles",
        type=int,
        default=1,
        help="Number of hard resets of the lr in cosine_with_restarts scheduler.",
    )
    parser.add_argument("--lr_power", type=float, default=1.0, help="Power factor of the polynomial scheduler.")
    parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
    parser.add_argument(
        "--checkpointing_steps",
        type=int,
        default=500,
        help=(
            "Save a checkpoint of the training state every X updates. These checkpoints are only suitable for resuming"
            " training using `--resume_from_checkpoint`."
        ),
    )
    parser.add_argument(
        "--checkpoints_total_limit",
        type=int,
        default=None,
        help=("Max number of checkpoints to store."),
    )
    parser.add_argument(
        "--train_text_encoder",
        action="store_true"
    )
    parser.add_argument(
        "--resume_from_checkpoint",
        type=str,
        default=None,
        help=(
            "Whether training should be resumed from a previous checkpoint. Use a path saved by"
            ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
        ),
    )
    parser.add_argument(
        "--do_mask",
        action="store_true",
        help="Whether or not to use the mask image as an additional conditioning input.",
    )
    parser.add_argument(
        "--mask_column",
        type=str,
        default="mask_image",
        help="The column of the dataset containing the original image`s mask.",
    )

    args = parser.parse_args(input_args)

    # A launcher-provided LOCAL_RANK takes precedence over --local_rank.
    env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
    if env_local_rank != -1 and env_local_rank != args.local_rank:
        args.local_rank = env_local_rank

    # Sanity checks
    if args.dataset_name is None and args.train_data_jsonl is None:
        raise ValueError("Need either a dataset name or a training folder.")

    return args
def combine_rgb_and_mask_to_rgba(rgb_image, mask_image):
    """Merge an RGB image and a mask into a single RGBA image.

    The mask is converted to mode ``'L'`` when needed and becomes the alpha
    channel of the result.

    Raises:
        ValueError: If the two images do not have the same size.
    """
    if rgb_image.size != mask_image.size:
        raise ValueError("The RGB image and the mask image must have the same dimensions")

    # The alpha band must be single-channel grayscale.
    alpha = mask_image if mask_image.mode == 'L' else mask_image.convert('L')

    red, green, blue = rgb_image.split()
    return Image.merge("RGBA", (red, green, blue, alpha))
def convert_to_np(image, resolution):
    """Load ``image`` as a channel-first RGB array of size ``resolution`` x ``resolution``.

    ``image`` may be an image object or a string. The string ``"NONE"`` stands
    for a blank white image; any other string is opened as a file path. The
    image is converted to RGB, resized to a square, and returned as a NumPy
    array transposed to ``(2, 0, 1)`` axis order.

    Loading is best-effort: on any exception the error is printed and a blank
    white image of the requested size is returned instead.
    """
    target_size = (resolution, resolution)

    def _white_canvas():
        return PIL.Image.new("RGB", target_size, (255, 255, 255))

    try:
        if isinstance(image, str):
            image = _white_canvas() if image == "NONE" else PIL.Image.open(image)
        image = image.convert("RGB").resize(target_size)
        return np.array(image).transpose(2, 0, 1)
    except Exception as e:
        print("Load error", image)
        print(e)
        # Fall back to a blank image so one bad sample does not abort the run.
        image = _white_canvas()
        return np.array(image).transpose(2, 0, 1)
def import_model_class_from_model_name_or_path(
    pretrained_model_name_or_path: str, revision: str, subfolder: str = "text_encoder"
):
    """Return the text-encoder class named by the checkpoint's config.

    Reads the config stored under ``subfolder`` and looks at its first
    ``architectures`` entry. Only ``CLIPTextModelWithProjection`` and
    ``T5EncoderModel`` are recognised.

    Raises:
        ValueError: If the architecture is anything else.
    """
    config = PretrainedConfig.from_pretrained(
        pretrained_model_name_or_path, subfolder=subfolder, revision=revision
    )
    architecture = config.architectures[0]

    if architecture not in ("CLIPTextModelWithProjection", "T5EncoderModel"):
        raise ValueError(f"{architecture} is not supported.")

    if architecture == "T5EncoderModel":
        from transformers import T5EncoderModel

        return T5EncoderModel

    from transformers import CLIPTextModelWithProjection

    return CLIPTextModelWithProjection
def main():
if args.report_to == "wandb" and args.hub_token is not None:
raise ValueError(
"You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token."
" Please use `huggingface-cli login` to authenticate with the Hub."
)
logging_dir = os.path.join(args.output_dir, args.logging_dir)
accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
from accelerate import DistributedDataParallelKwargs as DDPK
kwargs = DDPK(find_unused_parameters=True)
accelerator = Accelerator(
gradient_accumulation_steps=args.gradient_accumulation_steps,
mixed_precision=args.mixed_precision,
log_with=args.report_to,
project_config=accelerator_project_config,
kwargs_handlers=[kwargs],
)
if torch.backends.mps.is_available():
accelerator.native_amp = False
def download_image(path_or_url, resolution=512):
    """Load an RGB image from a local path or a URL, resized to a square.

    Args:
        path_or_url: Local file path, URL, or ``None``.
        resolution: Side length of the returned image.

    Returns:
        A blank white image when ``path_or_url`` is ``None``; otherwise the
        loaded image converted to RGB, resized to ``resolution`` x
        ``resolution`` and passed through ``exif_transpose``.
    """
    if path_or_url is None:
        # No image given: return a white RGB placeholder.
        return PIL.Image.new("RGB", (resolution, resolution), (255, 255, 255))

    if os.path.exists(path_or_url):
        image = Image.open(path_or_url)
    else:
        # Bound the request so an unresponsive host cannot hang the run.
        response = requests.get(path_or_url, stream=True, timeout=30)
        image = Image.open(response.raw)

    # Apply the same conversion and resize to local files and downloads, so
    # `resolution` is honoured for both kinds of input.
    image = image.convert("RGB").resize((resolution, resolution))
    image = PIL.ImageOps.exif_transpose(image)
    return image
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)
if args.report_to == "wandb":
if not is_wandb_available():
raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
import wandb
# Make one log on every process with the configuration for debugging.
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
datefmt="%m/%d/%Y %H:%M:%S",
level=logging.INFO,
)
logger.info(accelerator.state, main_process_only=False)
if accelerator.is_local_main_process:
datasets.utils.logging.set_verbosity_warning()
transformers.utils.logging.set_verbosity_warning()
diffusers.utils.logging.set_verbosity_info()
else:
datasets.utils.logging.set_verbosity_error()
transformers.utils.logging.set_verbosity_error()
diffusers.utils.logging.set_verbosity_error()
# If passed along, set the training seed now.
if args.seed is not None:
set_seed(args.seed)
# Handle the repository creation
if accelerator.is_main_process:
if args.output_dir is not None:
os.makedirs(args.output_dir, exist_ok=True)
if args.push_to_hub:
repo_id = create_repo(
repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token
).repo_id
# Load scheduler, tokenizer and models.
# Load the tokenizers
tokenizer_one = CLIPTokenizer.from_pretrained(
args.pretrained_model_name_or_path,
subfolder="tokenizer",
revision=args.revision,
)
tokenizer_two = CLIPTokenizer.from_pretrained(
args.pretrained_model_name_or_path,
subfolder="tokenizer_2",
revision=args.revision,
)
tokenizer_three = T5TokenizerFast.from_pretrained(
args.pretrained_model_name_or_path,
subfolder="tokenizer_3",
revision=args.revision,
)
# import correct text encoder classes
text_encoder_cls_one = import_model_class_from_model_name_or_path(
args.pretrained_model_name_or_path, args.revision
)
text_encoder_cls_two = import_model_class_from_model_name_or_path(
args.pretrained_model_name_or_path, args.revision, subfolder="text_encoder_2"
)
text_encoder_cls_three = import_model_class_from_model_name_or_path(
args.pretrained_model_name_or_path, args.revision, subfolder="text_encoder_3"
)
# Load scheduler and models
noise_scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
args.pretrained_model_name_or_path, subfolder="scheduler"
)
noise_scheduler_copy = copy.deepcopy(noise_scheduler)
text_encoder_one, text_encoder_two, text_encoder_three = load_text_encoders(
text_encoder_cls_one, text_encoder_cls_two, text_encoder_cls_three
)
vae = AutoencoderKL.from_pretrained(
args.pretrained_model_name_or_path,
subfolder="vae",
revision=args.revision,
variant=args.variant,
)
transformer = SD3Transformer2DModel.from_pretrained(
args.pretrained_model_name_or_path, subfolder="transformer", revision=args.revision, variant=args.variant
)
# TODO
logger.info("Initializing the new channel of DIT from the pretrained DIT.")
in_channels = int(1.5 * transformer.config.in_channels) if args.do_mask else 2 * transformer.config.in_channels # 48 for mask
out_channels = transformer.pos_embed.proj.out_channels
load_num_channel = transformer.config.in_channels
print("Do mask",args.do_mask)
print("new in_channels",in_channels)
print("load_num_channel",load_num_channel)
transformer.register_to_config(in_channels=in_channels)
print("transformer.pos_embed.proj.weight.shape", transformer.pos_embed.proj.weight.shape)
print("load_num_channel", load_num_channel)
with torch.no_grad():
new_proj = nn.Conv2d(
in_channels, out_channels, kernel_size=(transformer.config.patch_size, transformer.config.patch_size),
stride=transformer.config.patch_size, bias=True
)
print("new_proj", new_proj)
new_proj.weight.zero_()
# init.kaiming_normal_(new_proj.weight, mode='fan_out', nonlinearity='relu')
# if new_proj.bias is not None and transformer.pos_embed.proj.bias is not None:
# new_proj.bias.copy_(transformer.pos_embed.proj.bias)
# else:
# if new_proj.bias is not None:
# new_proj.bias.zero_()
new_proj = new_proj.to(transformer.pos_embed.proj.weight.dtype)
new_proj.weight[:, :load_num_channel, :, :].copy_(transformer.pos_embed.proj.weight)
new_proj.bias.copy_(transformer.pos_embed.proj.bias)
print("new_proj", new_proj.weight.shape)
print("transformer.pos_embed.proj", transformer.pos_embed.proj.weight.shape)
transformer.pos_embed.proj = new_proj
for param in transformer.parameters():
param.requires_grad = True
transformer.requires_grad_(True)
vae.requires_grad_(False)
if args.train_text_encoder:
text_encoder_one.requires_grad_(True)
text_encoder_two.requires_grad_(True)
text_encoder_three.requires_grad_(True)
else:
text_encoder_one.requires_grad_(False)
text_encoder_two.requires_grad_(False)
text_encoder_three.requires_grad_(False)
# For mixed precision training we cast all non-trainable weights (vae, non-lora text_encoder and non-lora transformer) to half-precision
# as these weights are only used for inference, keeping weights in full precision is not required.
weight_dtype = torch.float32
if accelerator.mixed_precision == "fp16":
weight_dtype = torch.float16
elif accelerator.mixed_precision == "bf16":
weight_dtype = torch.bfloat16
if torch.backends.mps.is_available() and weight_dtype == torch.bfloat16:
# due to pytorch#99272, MPS does not yet support bfloat16.
raise ValueError(
"Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead."
)
vae.to(accelerator.device, dtype=torch.float32)
if not args.train_text_encoder:
text_encoder_one.to(accelerator.device, dtype=weight_dtype)
text_encoder_two.to(accelerator.device, dtype=weight_dtype)
text_encoder_three.to(accelerator.device, dtype=weight_dtype)
if args.gradient_checkpointing:
transformer.enable_gradient_checkpointing()
if args.train_text_encoder:
text_encoder_one.gradient_checkpointing_enable()
text_encoder_two.gradient_checkpointing_enable()
text_encoder_three.gradient_checkpointing_enable()
def unwrap_model(model):
    """Return ``model`` with the accelerator wrapper and any compile wrapper removed."""
    unwrapped = accelerator.unwrap_model(model)
    if is_compiled_module(unwrapped):
        # A compiled module keeps the original module on `_orig_mod`.
        return unwrapped._orig_mod
    return unwrapped
# create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
def save_model_hook(models, weights, output_dir):
    """Save each tracked model with ``save_pretrained`` under ``output_dir``.

    Registered as the accelerator's save-state pre-hook. Only the main process
    writes. The transformer goes to ``transformer``; CLIP text encoders go to
    ``text_encoder`` (hidden size 768) or ``text_encoder_2`` (hidden size
    1280); the T5 encoder goes to ``text_encoder_3``. A CLIP encoder with any
    other hidden size is not written.

    Raises:
        ValueError: If a model is neither the transformer nor a supported
            text encoder.
    """
    if not accelerator.is_main_process:
        return

    clip_subfolders = {768: "text_encoder", 1280: "text_encoder_2"}
    for model in models:
        target = unwrap_model(model)
        if isinstance(target, SD3Transformer2DModel):
            subfolder = "transformer"
        elif isinstance(target, CLIPTextModelWithProjection):
            subfolder = clip_subfolders.get(target.config.hidden_size)
        elif isinstance(target, T5EncoderModel):
            subfolder = "text_encoder_3"
        else:
            raise ValueError(f"Wrong model supplied: {type(model)=}.")

        if subfolder is not None:
            target.save_pretrained(os.path.join(output_dir, subfolder))

        # make sure to pop weight so that corresponding model is not saved again
        weights.pop()
def load_model_hook(models, input_dir):
    """Restore each tracked model from the layout written by ``save_model_hook``.

    Registered as the accelerator's load-state pre-hook. Models are popped
    from ``models`` so they are not loaded a second time. The subfolder is
    chosen by model type and, for CLIP encoders, by hidden size (768 ->
    ``text_encoder``, 1280 -> ``text_encoder_2``), matching the save hook.
    Weights are loaded into the unwrapped module.

    Raises:
        ValueError: If a model type is unsupported, or a text encoder
            checkpoint cannot be loaded (the original error is chained).
    """
    clip_subfolders = {768: "text_encoder", 1280: "text_encoder_2"}

    while models:
        # pop models so that they are not loaded again
        model = models.pop()
        target = unwrap_model(model)

        # load diffusers style into model
        if isinstance(target, SD3Transformer2DModel):
            load_model = SD3Transformer2DModel.from_pretrained(input_dir, subfolder="transformer")
            target.register_to_config(**load_model.config)
            target.load_state_dict(load_model.state_dict())
        elif isinstance(target, (CLIPTextModelWithProjection, T5EncoderModel)):
            if isinstance(target, CLIPTextModelWithProjection):
                encoder_cls = CLIPTextModelWithProjection
                subfolder = clip_subfolders.get(target.config.hidden_size)
            else:
                encoder_cls = T5EncoderModel
                subfolder = "text_encoder_3"

            if subfolder is None:
                raise ValueError(f"Couldn't load the model of type: ({type(model)}).")

            # Only the weights are restored; the encoder keeps its existing
            # config (the model must not be called with the config).
            try:
                load_model = encoder_cls.from_pretrained(input_dir, subfolder=subfolder)
                target.load_state_dict(load_model.state_dict())
            except Exception as err:
                raise ValueError(f"Couldn't load the model of type: ({type(model)}).") from err
        else:
            raise ValueError(f"Unsupported model found: {type(model)=}")

        # Drop the temporary copy before loading the next model.
        del load_model
accelerator.register_save_state_pre_hook(save_model_hook)
accelerator.register_load_state_pre_hook(load_model_hook)
# Enable TF32 for faster training on Ampere GPUs,
# cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
if args.allow_tf32:
torch.backends.cuda.matmul.allow_tf32 = True
if args.scale_lr:
args.learning_rate = (
args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
)
transformer_parameters_with_lr = {"params": transformer.parameters(), "lr": args.learning_rate}
if args.train_text_encoder:
# different learning rate for text encoder and unet
text_parameters_one_with_lr = {
"params": text_encoder_one.parameters(),
"weight_decay": args.adam_weight_decay_text_encoder,
"lr": args.text_encoder_lr if args.text_encoder_lr else args.learning_rate,
}
text_parameters_two_with_lr = {
"params": text_encoder_two.parameters(),
"weight_decay": args.adam_weight_decay_text_encoder,
"lr": args.text_encoder_lr if args.text_encoder_lr else args.learning_rate,
}
text_parameters_three_with_lr = {
"params": text_encoder_three.parameters(),
"weight_decay": args.adam_weight_decay_text_encoder,
"lr": args.text_encoder_lr if args.text_encoder_lr else args.learning_rate,
}
params_to_optimize = [
transformer_parameters_with_lr,
text_parameters_one_with_lr,
text_parameters_two_with_lr,
text_parameters_three_with_lr,
]
else:
params_to_optimize = [transformer_parameters_with_lr]
if not (args.optimizer.lower() == "prodigy" or args.optimizer.lower() == "adamw"):
logger.warning(
f"Unsupported choice of optimizer: {args.optimizer}.Supported optimizers include [adamW, prodigy]."
"Defaulting to adamW"
)
args.optimizer = "adamw"
# Initialize the optimizer
if args.use_8bit_adam and not args.optimizer.lower() == "adamw":
logger.warning(
f"use_8bit_adam is ignored when optimizer is not set to 'AdamW'. Optimizer was "
f"set to {args.optimizer.lower()}"
)
if args.optimizer.lower() == "adamw":
if args.use_8bit_adam:
try:
import bitsandbytes as bnb
except ImportError:
raise ImportError(
"To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
)
optimizer_class = bnb.optim.AdamW8bit
else:
optimizer_class = torch.optim.AdamW
optimizer = optimizer_class(
params_to_optimize,
betas=(args.adam_beta1, args.adam_beta2),
weight_decay=args.adam_weight_decay,
eps=args.adam_epsilon,
)
if args.optimizer.lower() == "prodigy":
try:
import prodigyopt
except ImportError:
raise ImportError("To use Prodigy, please install the prodigyopt library: `pip install prodigyopt`")
optimizer_class = prodigyopt.Prodigy
if args.learning_rate <= 0.1:
logger.warning(
"Learning rate is too low. When using prodigy, it's generally better to set learning rate around 1.0"
)
if args.train_text_encoder and args.text_encoder_lr:
logger.warning(
f"Learning rates were provided both for the transformer and the text encoder- e.g. text_encoder_lr:"
f" {args.text_encoder_lr} and learning_rate: {args.learning_rate}. "
f"When using prodigy only learning_rate is used as the initial learning rate."
)
# changes the learning rate of all three text encoder parameter groups
# (params_to_optimize[1..3]) to be --learning_rate
params_to_optimize[1]["lr"] = args.learning_rate
params_to_optimize[2]["lr"] = args.learning_rate
params_to_optimize[3]["lr"] = args.learning_rate
optimizer = optimizer_class(
params_to_optimize,
lr=args.learning_rate,
betas=(args.adam_beta1, args.adam_beta2),
beta3=args.prodigy_beta3,
weight_decay=args.adam_weight_decay,
eps=args.adam_epsilon,
decouple=args.prodigy_decouple,
use_bias_correction=args.prodigy_use_bias_correction,
safeguard_warmup=args.prodigy_safeguard_warmup,
)
text_encoders_dtypes = [text_encoder_one.dtype, text_encoder_two.dtype, text_encoder_three.dtype]
if not args.train_text_encoder:
tokenizers = [tokenizer_one, tokenizer_two, tokenizer_three]
text_encoders = [text_encoder_one, text_encoder_two, text_encoder_three]
def compute_text_embeddings(prompt, text_encoders, tokenizers, text_encoders_dtypes):
    """Encode ``prompt`` with the given text encoders without tracking gradients.

    Args:
        prompt: prompt(s) forwarded unchanged to ``encode_prompt``.
        text_encoders: text encoder models forwarded to ``encode_prompt``.
        tokenizers: tokenizers forwarded to ``encode_prompt``.
        text_encoders_dtypes: dtypes forwarded to ``encode_prompt``.

    Returns:
        The ``(prompt_embeds, pooled_prompt_embeds)`` pair returned by
        ``encode_prompt``, each moved to ``accelerator.device``.
    """
    with torch.no_grad():
        embeds, pooled_embeds = encode_prompt(
            text_encoders,
            tokenizers,
            prompt,
            args.max_sequence_length,
            text_encoders_dtypes,
        )
        embeds = embeds.to(accelerator.device)
        pooled_embeds = pooled_embeds.to(accelerator.device)
    return embeds, pooled_embeds
# Get the datasets: you can either provide your own training and evaluation files (see below)
# or specify a Dataset from the hub (the dataset will be downloaded automatically from the datasets Hub).
# In distributed training, the load_dataset function guarantees that only one local process can concurrently
# download the dataset.
if args.dataset_name is not None:
    # Downloading and loading a dataset from the hub.
    dataset = load_dataset(
        args.dataset_name,
        args.dataset_config_name,
        cache_dir=args.cache_dir,
    )
elif args.train_data_jsonl is not None:
    # Load local JSON / JSON-lines file(s). No `split` is requested, so the result is
    # indexed as `dataset["train"]` by the code below.
    dataset = load_dataset(
        "json",
        data_files=args.train_data_jsonl,
        cache_dir=args.cache_dir,
    )
else:
    # Without this branch `dataset` would stay unbound and the first use below would fail
    # with an opaque unbound-variable error instead of a clear configuration message.
    raise ValueError("Either `dataset_name` or `train_data_jsonl` must be provided to load training data.")
# See more about loading custom images at
# https://huggingface.co/docs/datasets/main/en/image_load#imagefolder
# Preprocessing the datasets.
# We need to tokenize inputs and targets.
column_names = dataset["train"].column_names
# 6. Get the column names for input/target.
dataset_columns = DATASET_NAME_MAPPING.get(args.dataset_name, None)
if args.original_image_column is None:
original_image_column = dataset_columns[0] if dataset_columns is not None else column_names[0]
else:
original_image_column = args.original_image_column
if original_image_column not in column_names:
raise ValueError(
f"--original_image_column' value '{args.original_image_column}' needs to be one of: {', '.join(column_names)}"
)
if args.edit_prompt_column is None:
edit_prompt_column = dataset_columns[1] if dataset_columns is not None else column_names[1]
else:
edit_prompt_column = args.edit_prompt_column
if edit_prompt_column not in column_names:
raise ValueError(
f"--edit_prompt_column' value '{args.edit_prompt_column}' needs to be one of: {', '.join(column_names)}"
)
if args.edited_image_column is None:
edited_image_column = dataset_columns[2] if dataset_columns is not None else column_names[2]
else:
edited_image_column = args.edited_image_column
if edited_image_column not in column_names:
raise ValueError(
f"--edited_image_column' value '{args.edited_image_column}' needs to be one of: {', '.join(column_names)}"
)
# Preprocessing the datasets.
# We need to tokenize input captions and transform the images.
# def tokenize_captions(captions):
# inputs = tokenizer(
# captions, max_length=tokenizer.model_max_length, padding="max_length", truncation=True, return_tensors="pt"
# )
# return inputs.input_ids
# Preprocessing the datasets.
train_transforms = transforms.Compose(
[
transforms.CenterCrop(args.resolution) if args.center_crop else transforms.RandomCrop(args.resolution),
transforms.RandomHorizontalFlip() if args.random_flip else transforms.Lambda(lambda x: x),
]
)
def preprocess_images(examples):
    """Convert a batch of image columns into one jointly augmented tensor.

    The original images, the edited images and (when ``args.do_mask`` is set) the
    mask images are each converted with ``convert_to_np`` and concatenated, in that
    order, into a single array. The array is rescaled with ``2 * (x / 255) - 1`` and
    passed through ``train_transforms`` in one call, so the original and the edited
    images undergo the same augmentation transforms.

    Args:
        examples: batch mapping column names to lists of images.

    Returns:
        The output of ``train_transforms`` applied to the concatenated tensor.
    """

    def _column_as_array(column):
        # One array per image, joined along the first axis.
        return np.concatenate([convert_to_np(img, args.resolution) for img in examples[column]])

    stacks = [_column_as_array(original_image_column), _column_as_array(edited_image_column)]
    if args.do_mask:
        stacks.append(_column_as_array(args.mask_column))

    batch = torch.tensor(np.concatenate(stacks))
    batch = 2 * (batch / 255) - 1
    return train_transforms(batch)
def preprocess_train(examples):
    """Attach preprocessed pixel tensors to a batch of dataset examples.

    ``preprocess_images`` returns the original, edited and (optionally) mask images
    concatenated into one tensor; this function splits that tensor back into equal
    chunks and reshapes each chunk to ``(-1, 3, resolution, resolution)``.

    Args:
        examples: batch mapping column names to lists of values; modified in place.

    Returns:
        ``examples`` with ``"original_pixel_values"`` and ``"edited_pixel_values"``
        added, plus ``"mask_pixel_values"`` when ``args.do_mask`` is set.
    """
    image_shape = (-1, 3, args.resolution, args.resolution)
    stacked = preprocess_images(examples)

    if args.do_mask:
        source_part, target_part, mask_part = stacked.chunk(3)
        examples["mask_pixel_values"] = mask_part.reshape(image_shape)
    else:
        source_part, target_part = stacked.chunk(2)

    examples["original_pixel_values"] = source_part.reshape(image_shape)
    examples["edited_pixel_values"] = target_part.reshape(image_shape)
    return examples
with accelerator.main_process_first():
if args.top_training_data_sample is not None:
dataset["train"] = dataset["train"].select(range(args.top_training_data_sample)).shuffle(seed=args.seed)
if args.max_train_samples is not None:
dataset["train"] = dataset["train"].shuffle(seed=args.seed).select(range(args.max_train_samples))
# Set the training transforms
train_dataset = dataset["train"].with_transform(preprocess_train)
def collate_fn(examples):
    """Collate per-example dicts into one training batch.

    Args:
        examples: list of examples, each holding ``"original_pixel_values"``,
            ``"edited_pixel_values"``, the prompt under ``edit_prompt_column`` and,
            when ``args.do_mask`` is set, ``"mask_pixel_values"``.

    Returns:
        Dict with the stacked float pixel tensors and the list of prompts (keyed by
        ``edit_prompt_column``); ``"mask_pixel_values"`` is included only when
        ``args.do_mask`` is set.
    """

    def _stack(key):
        # Stack along a new batch axis, then force contiguous float32 storage.
        stacked = torch.stack([item[key] for item in examples])
        return stacked.to(memory_format=torch.contiguous_format).float()

    batch = {
        "original_pixel_values": _stack("original_pixel_values"),
        "edited_pixel_values": _stack("edited_pixel_values"),
        edit_prompt_column: [item[edit_prompt_column] for item in examples],
    }
    if args.do_mask:
        batch["mask_pixel_values"] = _stack("mask_pixel_values")
    return batch
# DataLoaders creation:
train_dataloader = torch.utils.data.DataLoader(
train_dataset,
shuffle=True,
collate_fn=collate_fn,
batch_size=args.train_batch_size,
num_workers=args.dataloader_num_workers,
)
# Scheduler and math around the number of training steps.
overrode_max_train_steps = False
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
if args.max_train_steps is None:
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
overrode_max_train_steps = True
lr_scheduler = get_scheduler(
args.lr_scheduler,
optimizer=optimizer,
num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
num_training_steps=args.max_train_steps * accelerator.num_processes,
num_cycles=args.lr_num_cycles,
power=args.lr_power,
)
# Prepare everything with our `accelerator`.
if args.train_text_encoder:
(
transformer,
text_encoder_one,
text_encoder_two,
text_encoder_three,
optimizer,
train_dataloader,
lr_scheduler,
) = accelerator.prepare(
transformer,
text_encoder_one,
text_encoder_two,
text_encoder_three,
optimizer,
train_dataloader,
lr_scheduler,
)
else:
transformer, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
transformer, optimizer, train_dataloader, lr_scheduler
)
# Before training starts, the main process runs the editing pipeline once with the current
# weights and saves the outputs in a `start_test` folder inside `args.output_dir`.
# After that, we need to recalculate our total training steps as the size of the training dataloader may have changed.
# Pre-training sanity run on the main process: generate `args.num_validation_images` edits with
# the current weights and store them under `<output_dir>/start_test`.
# The run is skipped when no validation image or prompt is configured, mirroring the guard used
# by the periodic validation inside the training loop.
if (
    accelerator.is_main_process
    and args.val_image_url is not None
    and args.validation_prompt is not None
):
    pretrained_path = args.pretrained_model_name_or_path
    pipeline = StableDiffusion3InstructPix2PixPipeline.from_pretrained(
        pretrained_path,
        vae=vae,
        text_encoder=accelerator.unwrap_model(text_encoder_one),
        text_encoder_2=accelerator.unwrap_model(text_encoder_two),
        text_encoder_3=accelerator.unwrap_model(text_encoder_three),
        transformer=accelerator.unwrap_model(transformer),
        revision=args.revision,
        variant=args.variant,
        torch_dtype=weight_dtype,
    )
    pipeline = pipeline.to(accelerator.device)
    pipeline.set_progress_bar_config(disable=True)

    # `is not None` (rather than truthiness) so that an explicit seed of 0 is honoured.
    generator = (
        torch.Generator(device=accelerator.device).manual_seed(args.seed)
        if args.seed is not None
        else None
    )

    original_image = download_image(args.val_image_url, args.eval_resolution)
    mask_image = download_image(args.val_mask_url, args.eval_resolution) if args.do_mask else None

    edited_images = []
    # `device.type` is the bare backend name ("cuda", "cpu", ...) that autocast expects,
    # independent of the device index.
    with torch.autocast(
        accelerator.device.type,
        enabled=accelerator.mixed_precision in ("fp16", "bf16"),
    ):
        for _ in range(args.num_validation_images):
            edited_images.append(
                pipeline(
                    args.validation_prompt,
                    image=original_image,
                    mask_img=mask_image,
                    num_inference_steps=50,
                    image_guidance_scale=1.5,
                    guidance_scale=7.5,
                    generator=generator,
                ).images[0]
            )

    path = join(args.output_dir, "start_test")
    os.makedirs(path, exist_ok=True)
    original_image.save(join(path, "original.jpg"))
    for idx, edited_image in enumerate(edited_images):
        edited_image.save(join(path, f"sample_{idx}.jpg"))

    # Drop the pipeline wrapper and release cached GPU memory before training starts.
    del pipeline
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    gc.collect()
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
if overrode_max_train_steps:
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
# Afterwards we recalculate our number of training epochs
print('=========num_update_steps_per_epoch==========', num_update_steps_per_epoch)
args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
# We need to initialize the trackers we use, and also store our configuration.
# The trackers are initialized automatically on the main process.
if accelerator.is_main_process:
accelerator.init_trackers("instruct-pix2pix_sd3", config=vars(args))
# Train!
total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
logger.info("***** Running training *****")
logger.info(f" Num examples = {len(train_dataset)}")
logger.info(f" Num Epochs = {args.num_train_epochs}")
logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
logger.info(f" Total optimization steps = {args.max_train_steps}")
global_step = 0
first_epoch = 0
# Potentially load in the weights and states from a previous save
if args.resume_from_checkpoint:
if args.resume_from_checkpoint != "latest":
path = os.path.basename(args.resume_from_checkpoint)
else:
# Get the most recent checkpoint
dirs = os.listdir(args.output_dir)
dirs = [d for d in dirs if d.startswith("checkpoint")]
dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
path = dirs[-1] if len(dirs) > 0 else None
if path is None:
accelerator.print(
f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
)
args.resume_from_checkpoint = None
initial_global_step = 0
else:
accelerator.print(f"Resuming from checkpoint {path}")
accelerator.load_state(os.path.join(args.output_dir, path))
global_step = int(path.split("-")[1])
initial_global_step = global_step
first_epoch = global_step // num_update_steps_per_epoch
resume_global_step = global_step * args.gradient_accumulation_steps
resume_step = resume_global_step % (num_update_steps_per_epoch * args.gradient_accumulation_steps)
else:
initial_global_step = 0
# Only show the progress bar once on each machine.
progress_bar = tqdm(range(0, args.max_train_steps), initial=initial_global_step, desc="Steps",
disable=not accelerator.is_local_main_process)
def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
    """Look up the scheduler sigma for each timestep and make it broadcastable.

    Args:
        timesteps: tensor of timestep values; each value is matched by equality
            against ``noise_scheduler_copy.timesteps``.
        n_dim: trailing singleton dimensions are appended until the result has at
            least this many dimensions.
        dtype: dtype the scheduler sigmas are cast to before indexing.

    Returns:
        Tensor on ``accelerator.device`` holding one sigma per timestep along dim 0.
    """
    device = accelerator.device
    sigma_table = noise_scheduler_copy.sigmas.to(device=device, dtype=dtype)
    timestep_table = noise_scheduler_copy.timesteps.to(device)

    # `.item()` only succeeds for exactly one match, so each timestep must appear
    # exactly once in the scheduler's table.
    positions = [(timestep_table == t).nonzero().item() for t in timesteps.to(device)]

    sigma = sigma_table[positions].flatten()
    for _ in range(n_dim - sigma.ndim):
        sigma = sigma.unsqueeze(-1)
    return sigma
# with torch.autograd.set_detect_anomaly(True):
for epoch in range(first_epoch, args.num_train_epochs):
transformer.train()
if args.train_text_encoder:
text_encoder_one.train()
text_encoder_two.train()
text_encoder_three.train()
train_loss = 0.0
for step, batch in enumerate(train_dataloader):
# Skip steps until we reach the resumed step
# if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step:
# if step % args.gradient_accumulation_steps == 0:
# progress_bar.update(1)
# continue
models_to_accumulate = [transformer]
if args.train_text_encoder:
models_to_accumulate.extend([text_encoder_one, text_encoder_two, text_encoder_three])
with accelerator.accumulate(models_to_accumulate):
# We want to learn the denoising process w.r.t the edited images which
# are conditioned on the original image (which was edited) and the edit instruction.
# So, first, convert images to latent space.
pixel_values = batch["edited_pixel_values"].to(dtype=vae.dtype)
prompt = batch[edit_prompt_column]
if not args.train_text_encoder:
prompt_embeds, pooled_prompt_embeds = compute_text_embeddings(
prompt, text_encoders, tokenizers,text_encoders_dtypes
)
else:
tokens_one = tokenize_prompt(tokenizer_one, prompt)
tokens_two = tokenize_prompt(tokenizer_two, prompt)
tokens_three = tokenize_prompt(tokenizer_three, prompt, args.max_sequence_length)
latents = vae.encode(pixel_values).latent_dist.sample()
latents = latents * vae.config.scaling_factor
latents = latents.to(dtype=weight_dtype)
# Sample noise that we'll add to the latents
noise = torch.randn_like(latents)
bsz = latents.shape[0]
# Sample a random timestep for each image
# for weighting schemes where we sample timesteps non-uniformly
if args.weighting_scheme == "logit_normal":
# See 3.1 in the SD3 paper ($rf/lognorm(0.00,1.00)$).
u = torch.normal(mean=args.logit_mean, std=args.logit_std, size=(bsz,), device="cpu")
u = torch.nn.functional.sigmoid(u)
elif args.weighting_scheme == "mode":
u = torch.rand(size=(bsz,), device="cpu")
u = 1 - u - args.mode_scale * (torch.cos(math.pi * u / 2) ** 2 - 1 + u)
else:
u = torch.rand(size=(bsz,), device="cpu")
indices = (u * noise_scheduler_copy.config.num_train_timesteps).long()
timesteps = noise_scheduler_copy.timesteps[indices].to(device=latents.device)
# Add noise to the latents according to the noise magnitude at each timestep
# (this is the forward diffusion process)
sigmas = get_sigmas(timesteps, n_dim=latents.ndim, dtype=latents.dtype)
noisy_model_input = sigmas * noise + (1.0 - sigmas) * latents
# Get the additional image embedding for conditioning.
# Instead of getting a diagonal Gaussian here, we simply take the mode.
original_image_embeds = vae.encode(batch["original_pixel_values"].to(vae.dtype)).latent_dist.mode()
concatenated_noisy_latents = torch.cat([noisy_model_input, original_image_embeds], dim=1)
if args.do_mask:
mask_embeds = vae.encode(batch["mask_pixel_values"].to(vae.dtype)).latent_dist.mode()
concatenated_noisy_latents = torch.cat([concatenated_noisy_latents, mask_embeds], dim=1)
# Predict the noise residual
# When the text encoders are trained, the prompt embeddings are computed here, outside of
# `torch.no_grad()`, from the token ids produced earlier in this step. Otherwise
# `prompt_embeds` / `pooled_prompt_embeds` were already computed by `compute_text_embeddings`.
if args.train_text_encoder:
    prompt_embeds, pooled_prompt_embeds = encode_prompt(
        text_encoders=[text_encoder_one, text_encoder_two, text_encoder_three],
        tokenizers=[tokenizer_one, tokenizer_two, tokenizer_three],
        prompt=prompt,
        text_input_ids_list=[tokens_one, tokens_two, tokens_three],
        max_sequence_length=args.max_sequence_length,
        text_encoders_dtypes=text_encoders_dtypes,
    )

# Single forward pass shared by both modes.
# NOTE: `mask_index` is intentionally not passed. No such variable is assigned in this
# training step, so forwarding it raised NameError whenever the text encoders were trained.
model_pred = transformer(
    hidden_states=concatenated_noisy_latents,
    timestep=timesteps,
    encoder_hidden_states=prompt_embeds,
    pooled_projections=pooled_prompt_embeds,
    return_dict=False,
)[0]
model_pred = model_pred * (-sigmas) + noisy_model_input
# these weighting schemes use a uniform timestep sampling
# and instead post-weight the loss
if args.weighting_scheme == "sigma_sqrt":
weighting = (sigmas ** -2.0).float()
elif args.weighting_scheme == "cosmap":
bot = 1 - 2 * sigmas + 2 * sigmas ** 2
weighting = 2 / (math.pi * bot)
else:
weighting = torch.ones_like(sigmas)
target = latents
# Conditioning dropout to support classifier-free guidance during inference. For more details
# check out the section 3.2.1 of the original paper https://arxiv.org/abs/2211.09800.
# Concatenate the `original_image_embeds` with the `noisy_latents`.
# Get the target for loss depending on the prediction type
loss = torch.mean(
(weighting.float() * (model_pred.float() - target.float()) ** 2).reshape(target.shape[0], -1),
1,
)
loss = loss.mean()
avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean()
train_loss += avg_loss.item() / args.gradient_accumulation_steps
# Backpropagate
accelerator.backward(loss)
if accelerator.sync_gradients:
params_to_clip = (
itertools.chain(
transformer.parameters(),
text_encoder_one.parameters(),
text_encoder_two.parameters(),
text_encoder_three.parameters(),
)
if args.train_text_encoder
else transformer.parameters()
)
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
optimizer.step()
lr_scheduler.step()
optimizer.zero_grad()
# Checks if the accelerator has performed an optimization step behind the scenes
if accelerator.sync_gradients:
progress_bar.update(1)
global_step += 1
accelerator.log({"train_loss": train_loss}, step=global_step)
train_loss = 0.0
if accelerator.is_main_process:
if global_step % args.checkpointing_steps == 0:
# _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
if args.checkpoints_total_limit is not None:
checkpoints = os.listdir(args.output_dir)
checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))
# before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints
if len(checkpoints) >= args.checkpoints_total_limit:
num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1
removing_checkpoints = checkpoints[0:num_to_remove]
logger.info(
f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
)
logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}")
for removing_checkpoint in removing_checkpoints:
removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)
shutil.rmtree(removing_checkpoint)
save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
accelerator.save_state(save_path)
logger.info(f"Saved state to {save_path}")
# Periodic validation on the main process: every `args.validation_step` optimizer steps,
# generate `args.num_validation_images` edits of the validation image, write them to
# `<output_dir>/eval_<global_step>` and, if a wandb tracker is active, log them as a table.
if accelerator.is_main_process:
    if (
        (args.val_image_url is not None)
        and (args.validation_prompt is not None)
        and (global_step % args.validation_step == 0)
    ):
        logger.info(
            f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
            f" {args.validation_prompt}."
        )
        # create pipeline
        # With masking enabled, the components that are not overridden below are loaded
        # from `args.ori_model_name_or_path` instead of the training checkpoint path.
        if args.do_mask:
            pretrained_path = args.ori_model_name_or_path
        else:
            pretrained_path = args.pretrained_model_name_or_path
        pipeline = StableDiffusion3InstructPix2PixPipeline.from_pretrained(
            pretrained_path,
            vae=vae,
            text_encoder=accelerator.unwrap_model(text_encoder_one),
            text_encoder_2=accelerator.unwrap_model(text_encoder_two),
            text_encoder_3=accelerator.unwrap_model(text_encoder_three),
            transformer=accelerator.unwrap_model(transformer),
            revision=args.revision,
            variant=args.variant,
            torch_dtype=weight_dtype,
        )
        pipeline = pipeline.to(accelerator.device)
        pipeline.set_progress_bar_config(disable=True)

        # `is not None` (rather than truthiness) so that an explicit seed of 0 is honoured.
        generator = (
            torch.Generator(device=accelerator.device).manual_seed(args.seed)
            if args.seed is not None
            else None
        )

        # run inference
        original_image = download_image(args.val_image_url, args.eval_resolution)
        mask_image = download_image(args.val_mask_url, args.eval_resolution) if args.do_mask else None

        edited_images = []
        # `device.type` is the bare backend name ("cuda", "cpu", ...) that autocast expects,
        # independent of the device index.
        with torch.autocast(
            accelerator.device.type,
            enabled=accelerator.mixed_precision in ("fp16", "bf16"),
        ):
            for _ in range(args.num_validation_images):
                edited_images.append(
                    pipeline(
                        args.validation_prompt,
                        image=original_image,
                        mask_img=mask_image,
                        num_inference_steps=50,
                        image_guidance_scale=1.5,
                        guidance_scale=7.5,
                        generator=generator,
                    ).images[0]
                )

        # Save to disk once, independently of the configured trackers, so the images are
        # written even when no tracker is active.
        path = join(args.output_dir, f"eval_{global_step}")
        os.makedirs(path, exist_ok=True)
        original_image.save(join(path, "original.jpg"))
        for idx, edited_image in enumerate(edited_images):
            edited_image.save(join(path, f"sample_{idx}.jpg"))

        for tracker in accelerator.trackers:
            if tracker.name == "wandb":
                wandb_table = wandb.Table(columns=WANDB_TABLE_COL_NAMES)
                for edited_image in edited_images:
                    wandb_table.add_data(
                        wandb.Image(original_image), wandb.Image(edited_image), args.validation_prompt
                    )
                tracker.log({"validation": wandb_table})

        # Drop the pipeline wrapper and release cached GPU memory before training resumes.
        del pipeline
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        gc.collect()
logs = {"step_loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
progress_bar.set_postfix(**logs)
accelerator.log(logs, step=global_step)
if global_step >= args.max_train_steps:
break
accelerator.wait_for_everyone()
if accelerator.is_main_process:
transformer = unwrap_model(transformer)
if args.train_text_encoder:
text_encoder_one = unwrap_model(text_encoder_one)
text_encoder_two = unwrap_model(text_encoder_two)
text_encoder_three = unwrap_model(text_encoder_three)
pipeline = StableDiffusion3InstructPix2PixPipeline.from_pretrained(
args.pretrained_model_name_or_path,
transformer=transformer,
text_encoder=text_encoder_one,
text_encoder_2=text_encoder_two,
text_encoder_3=text_encoder_three,
)
else:
pipeline = StableDiffusion3InstructPix2PixPipeline.from_pretrained(
args.pretrained_model_name_or_path, transformer=transformer
)
pipeline.save_pretrained(args.output_dir)
accelerator.end_training()
# Script entry point.
# `args` is bound at module level rather than passed as a parameter: `main()` is called
# without arguments, and module-level helpers such as `load_text_encoders` read `args`
# as a global.
if __name__ == "__main__":
args = parse_args()
main()