import json
from time import time
import argparse
import logging
import os
from pathlib import Path
import math

import numpy as np
from PIL import Image
from copy import deepcopy

import torch
import torch.distributed as dist
from torch.utils.data import Dataset, DataLoader
from torch.utils.data.distributed import DistributedSampler
from torchvision import transforms

from accelerate import Accelerator
from accelerate.utils import ProjectConfiguration, set_seed
from diffusers.optimization import get_scheduler
from accelerate.utils import DistributedType
from peft import LoraConfig, set_peft_model_state_dict, PeftModel, get_peft_model
from peft.utils import get_peft_model_state_dict
from huggingface_hub import snapshot_download
from safetensors.torch import save_file

from diffusers.models import AutoencoderKL

from OmniGen import OmniGen, OmniGenProcessor
from OmniGen.train_helper import DatasetFromJson, TrainDataCollator
from OmniGen.train_helper import training_losses
from OmniGen.utils import (
    create_logger,
    update_ema,
    requires_grad,
    center_crop_arr,
    crop_arr,
    vae_encode,
    vae_encode_list,
)

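# End-to-end fine-tuning entry point: sets up Accelerate, loads the OmniGen
# transformer and VAE, optionally applies LoRA and EMA, builds the data
# pipeline, and runs the training loop with periodic logging and checkpointing.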
def main(args):
    from accelerate import DistributedDataParallelKwargs as DDPK
    kwargs = DDPK(find_unused_parameters=False)
    accelerator = Accelerator(
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        mixed_precision=args.mixed_precision,
        log_with=args.report_to,
        project_dir=args.results_dir,
        kwargs_handlers=[kwargs],
    )
    device = accelerator.device
    accelerator.init_trackers("tensorboard_log", config=args.__dict__)

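    # Set up the experiment folder, logger, and a dump of the training arguments: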
    checkpoint_dir = f"{args.results_dir}/checkpoints"
    logger = create_logger(args.results_dir)
    if accelerator.is_main_process:
        os.makedirs(checkpoint_dir, exist_ok=True)
        logger.info(f"Experiment directory created at {args.results_dir}")
        json.dump(args.__dict__, open(os.path.join(args.results_dir, 'train_args.json'), 'w'))

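    # Load the OmniGen model (downloading it from the Hugging Face Hub if the path is not local),
    # disable the KV cache, and enable gradient checkpointing: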
    if not os.path.exists(args.model_name_or_path):
        cache_folder = os.getenv('HF_HUB_CACHE')
        args.model_name_or_path = snapshot_download(repo_id=args.model_name_or_path,
                                                    cache_dir=cache_folder,
                                                    ignore_patterns=['flax_model.msgpack', 'rust_model.ot', 'tf_model.h5'])
        logger.info(f"Downloaded model to {args.model_name_or_path}")
    model = OmniGen.from_pretrained(args.model_name_or_path)
    model.llm.config.use_cache = False
    model.llm.gradient_checkpointing_enable()
    model = model.to(device)

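    # Load the VAE used to encode images into latents (falling back to the SDXL VAE if none is bundled with the model):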
    if args.vae_path is None:
        print(args.model_name_or_path)
        vae_path = os.path.join(args.model_name_or_path, "vae")
        if os.path.exists(vae_path):
            vae = AutoencoderKL.from_pretrained(vae_path).to(device)
        else:
            logger.info("No VAE found in model, downloading stabilityai/sdxl-vae from HF")
            logger.info("If you have a VAE in a local folder, please specify the path with --vae_path")
            vae = AutoencoderKL.from_pretrained("stabilityai/sdxl-vae").to(device)
    else:
        vae = AutoencoderKL.from_pretrained(args.vae_path).to(device)

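    # Mixed-precision setup: the transformer is cast to fp16/bf16, while the VAE stays in fp32 for stable image encoding: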
    weight_dtype = torch.float32
    if accelerator.mixed_precision == "fp16":
        weight_dtype = torch.float16
    elif accelerator.mixed_precision == "bf16":
        weight_dtype = torch.bfloat16
    vae.to(dtype=torch.float32)
    model.to(weight_dtype)

    processor = OmniGenProcessor.from_pretrained(args.model_name_or_path)

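    # Freeze the VAE. With --use_lora, freeze the base model and train only LoRA adapters on the
    # attention projections; otherwise fine-tune all parameters: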
    requires_grad(vae, False)
    if args.use_lora:
        if accelerator.distributed_type == DistributedType.FSDP:
            raise NotImplementedError("FSDP does not support LoRA")
        requires_grad(model, False)
        transformer_lora_config = LoraConfig(
            r=args.lora_rank,
            lora_alpha=args.lora_rank,
            init_lora_weights="gaussian",
            target_modules=["qkv_proj", "o_proj"],
        )
        model.llm.enable_input_require_grads()
        model = get_peft_model(model, transformer_lora_config)
        model.to(weight_dtype)
        transformer_lora_parameters = list(filter(lambda p: p.requires_grad, model.parameters()))
        for n, p in model.named_parameters():
            print(n, p.requires_grad)
        opt = torch.optim.AdamW(transformer_lora_parameters, lr=args.lr, weight_decay=args.adam_weight_decay)
    else:
        opt = torch.optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.adam_weight_decay)

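    # Optionally keep an exponential moving average (EMA) copy of the trainable weights: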
    ema = None
    if args.use_ema:
        ema = deepcopy(model).to(device)
        requires_grad(ema, False)

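    # Image preprocessing: crop to at most max_image_size (center-crop when raw resolution is not kept)
    # and normalize pixels to [-1, 1]: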
    crop_func = crop_arr
    if not args.keep_raw_resolution:
        crop_func = center_crop_arr
    image_transform = transforms.Compose([
        transforms.Lambda(lambda pil_image: crop_func(pil_image, args.max_image_size)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
    ])

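    # Build the JSON-backed dataset and the collator that pads and batches the text and image inputs: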
    dataset = DatasetFromJson(json_file=args.json_file,
                              image_path=args.image_path,
                              processer=processor,
                              image_transform=image_transform,
                              max_input_length_limit=args.max_input_length_limit,
                              condition_dropout_prob=args.condition_dropout_prob,
                              keep_raw_resolution=args.keep_raw_resolution
                              )
    collate_fn = TrainDataCollator(pad_token_id=processor.text_tokenizer.eos_token_id,
                                   hidden_size=model.llm.config.hidden_size,
                                   keep_raw_resolution=args.keep_raw_resolution)

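    # Per-device DataLoader; accelerator.prepare() below takes care of distributed sharding.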
    loader = DataLoader(
        dataset,
        collate_fn=collate_fn,
        batch_size=args.batch_size_per_device,
        shuffle=True,
        num_workers=args.num_workers,
        pin_memory=True,
        drop_last=True,
        prefetch_factor=2,
    )

    if accelerator.is_main_process:
        logger.info(f"Dataset contains {len(dataset):,} samples")

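    # Learning-rate schedule; warmup and total steps are expressed in micro-batches, hence the gradient_accumulation_steps factor: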
    num_update_steps_per_epoch = math.ceil(len(loader) / args.gradient_accumulation_steps)
    max_train_steps = args.epochs * num_update_steps_per_epoch
    lr_scheduler = get_scheduler(
        args.lr_scheduler,
        optimizer=opt,
        num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
        num_training_steps=max_train_steps * args.gradient_accumulation_steps,
    )

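    # Switch to train mode, initialize the EMA copy, and let Accelerate wrap the model, optimizer,
    # dataloader, and scheduler for distributed training: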
    model.train()

    if ema is not None:
        update_ema(ema, model, decay=0)
        ema.eval()

    if ema is not None:
        model, ema = accelerator.prepare(model, ema)
    else:
        model = accelerator.prepare(model)

    opt, loader, lr_scheduler = accelerator.prepare(opt, loader, lr_scheduler)

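    # Variables for monitoring and logging: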
    train_steps, log_steps = 0, 0
    running_loss = 0
    start_time = time()

    if accelerator.is_main_process:
        logger.info(f"Training for {args.epochs} epochs...")
    for epoch in range(args.epochs):
        if accelerator.is_main_process:
            logger.info(f"Beginning epoch {epoch}...")

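        # Gradients are accumulated over args.gradient_accumulation_steps micro-batches before each optimizer step.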
        for data in loader:
            with accelerator.accumulate(model):
                with torch.no_grad():
                    output_images = data['output_images']
                    input_pixel_values = data['input_pixel_values']
                    if isinstance(output_images, list):
                        output_images = vae_encode_list(vae, output_images, weight_dtype)
                        if input_pixel_values is not None:
                            input_pixel_values = vae_encode_list(vae, input_pixel_values, weight_dtype)
                    else:
                        output_images = vae_encode(vae, output_images, weight_dtype)
                        if input_pixel_values is not None:
                            input_pixel_values = vae_encode(vae, input_pixel_values, weight_dtype)

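                # Everything the transformer needs besides the target latents: text tokens,
                # conditioning image latents and their sizes, attention mask, and position ids.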
                model_kwargs = dict(
                    input_ids=data['input_ids'],
                    input_img_latents=input_pixel_values,
                    input_image_sizes=data['input_image_sizes'],
                    attention_mask=data['attention_mask'],
                    position_ids=data['position_ids'],
                    padding_latent=data['padding_images'],
                    past_key_values=None,
                    return_past_key_values=False,
                )

                loss_dict = training_losses(model, output_images, model_kwargs)
                loss = loss_dict["loss"].mean()

                running_loss += loss.item()
                accelerator.backward(loss)
                if args.max_grad_norm is not None and accelerator.sync_gradients:
                    accelerator.clip_grad_norm_(model.parameters(), args.max_grad_norm)
                opt.step()
                lr_scheduler.step()
                opt.zero_grad()

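                # train_steps counts micro-batches; optimizer steps = train_steps / gradient_accumulation_steps.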
                log_steps += 1
                train_steps += 1

                accelerator.log({"training_loss": loss.item()}, step=train_steps)
                if train_steps % args.gradient_accumulation_steps == 0:
                    if accelerator.sync_gradients and ema is not None:
                        update_ema(ema, model)

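                # Periodic logging: average the running loss across all processes and report training throughput.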
                if train_steps % (args.log_every * args.gradient_accumulation_steps) == 0 and train_steps > 0:
                    torch.cuda.synchronize()
                    end_time = time()
                    steps_per_sec = log_steps / args.gradient_accumulation_steps / (end_time - start_time)

                    avg_loss = torch.tensor(running_loss / log_steps, device=device)
                    dist.all_reduce(avg_loss, op=dist.ReduceOp.SUM)
                    avg_loss = avg_loss.item() / accelerator.num_processes

                    if accelerator.is_main_process:
                        cur_lr = opt.param_groups[0]["lr"]
                        logger.info(f"(step={int(train_steps/args.gradient_accumulation_steps):07d}) Train Loss: {avg_loss:.4f}, Train Steps/Sec: {steps_per_sec:.2f}, Epoch: {train_steps/len(loader):.2f}, LR: {cur_lr}")

                    running_loss = 0
                    log_steps = 0
                    start_time = time()

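                # Periodic checkpointing: only the main process writes files; all processes sync at the barrier afterwards.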
                if train_steps % (args.ckpt_every * args.gradient_accumulation_steps) == 0 and train_steps > 0:
                    if accelerator.distributed_type == DistributedType.FSDP:
                        state_dict = accelerator.get_state_dict(model)
                        ema_state_dict = accelerator.get_state_dict(ema) if ema is not None else None
                    else:
                        if not args.use_lora:
                            state_dict = model.module.state_dict()
                            ema_state_dict = accelerator.get_state_dict(ema) if ema is not None else None

                    if accelerator.is_main_process:
                        if args.use_lora:
                            checkpoint_path = f"{checkpoint_dir}/{int(train_steps/args.gradient_accumulation_steps):07d}/"
                            os.makedirs(checkpoint_path, exist_ok=True)
                            model.module.save_pretrained(checkpoint_path)
                        else:
                            checkpoint_path = f"{checkpoint_dir}/{int(train_steps/args.gradient_accumulation_steps):07d}/"
                            os.makedirs(checkpoint_path, exist_ok=True)
                            torch.save(state_dict, os.path.join(checkpoint_path, "model.pt"))
                            processor.text_tokenizer.save_pretrained(checkpoint_path)
                            model.llm.config.save_pretrained(checkpoint_path)
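                            # Also write the EMA weights (when enabled) to a separate "_ema" checkpoint directory: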
                            if ema_state_dict is not None:
                                checkpoint_path = f"{checkpoint_dir}/{int(train_steps/args.gradient_accumulation_steps):07d}_ema"
                                os.makedirs(checkpoint_path, exist_ok=True)
                                torch.save(ema_state_dict, os.path.join(checkpoint_path, "model.pt"))
                                processor.text_tokenizer.save_pretrained(checkpoint_path)
                                model.llm.config.save_pretrained(checkpoint_path)
                        logger.info(f"Saved checkpoint to {checkpoint_path}")

                    dist.barrier()
    accelerator.end_training()
    model.eval()

    if accelerator.is_main_process:
        logger.info("Done!")


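# Command-line interface. Example launch (illustrative only; adjust the model path,
# data files, and flags to your setup):
#   accelerate launch train.py \
#       --model_name_or_path Shitao/OmniGen-v1 \
#       --json_file ./toy_data/toy_data.jsonl \
#       --image_path ./toy_data/images \
#       --use_lora --lora_rank 8 \
#       --results_dir ./results/omnigen_lora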
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--results_dir", type=str, default="results")
    parser.add_argument("--model_name_or_path", type=str, default="OmniGen")
    parser.add_argument("--json_file", type=str)
    parser.add_argument("--image_path", type=str, default=None)
    parser.add_argument("--epochs", type=int, default=1400)
    parser.add_argument("--batch_size_per_device", type=int, default=1)
    parser.add_argument("--vae_path", type=str, default=None)
    parser.add_argument("--num_workers", type=int, default=4)
    parser.add_argument("--log_every", type=int, default=100)
    parser.add_argument("--ckpt_every", type=int, default=20000)
    parser.add_argument("--max_grad_norm", type=float, default=1.0)
    parser.add_argument("--lr", type=float, default=1e-4)
    parser.add_argument("--max_input_length_limit", type=int, default=1024)
    parser.add_argument("--condition_dropout_prob", type=float, default=0.1)
    parser.add_argument("--adam_weight_decay", type=float, default=0.0)
    parser.add_argument(
        "--keep_raw_resolution",
        action="store_true",
        help="Train with images at multiple (raw) resolutions instead of center-cropping them to a fixed size.",
    )
    parser.add_argument("--max_image_size", type=int, default=1344)

    parser.add_argument(
        "--use_lora",
        action="store_true",
    )
    parser.add_argument(
        "--lora_rank",
        type=int,
        default=8,
    )

    parser.add_argument(
        "--use_ema",
        action="store_true",
        help="Whether or not to use EMA.",
    )
    parser.add_argument(
        "--lr_scheduler",
        type=str,
        default="constant",
        help=(
            'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
            ' "constant", "constant_with_warmup"]'
        ),
    )
    parser.add_argument(
        "--lr_warmup_steps", type=int, default=1000, help="Number of steps for the warmup in the lr scheduler."
    )
    parser.add_argument(
        "--report_to",
        type=str,
        default="tensorboard",
        help=(
            'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
            ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
        ),
    )
    parser.add_argument(
        "--mixed_precision",
        type=str,
        default="bf16",
        choices=["no", "fp16", "bf16"],
        help=(
            "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
            " 1.10 and an Nvidia Ampere GPU. Defaults to the value of the accelerate config of the current system or the"
            " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
        ),
    )
    parser.add_argument(
        "--gradient_accumulation_steps",
        type=int,
        default=1,
        help="Number of update steps to accumulate before performing a backward/update pass.",
    )

    args = parser.parse_args()
    assert args.max_image_size % 16 == 0, "Image size must be divisible by 16."

    main(args)