import os
import math
import torch
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader, Sampler
from torch.utils.data.distributed import DistributedSampler
from torch.optim.lr_scheduler import LambdaLR
from collections import defaultdict
from diffusers import UNet2DConditionModel, AutoencoderKL
from accelerate import Accelerator
from datasets import load_from_disk
from tqdm import tqdm
from PIL import Image, ImageOps
import wandb
import random
import gc
from accelerate.state import DistributedType
from torch.distributed import broadcast_object_list
from torch.utils.checkpoint import checkpoint
from diffusers.models.attention_processor import AttnProcessor2_0
from datetime import datetime
import bitsandbytes as bnb
import torch.nn.functional as F
from collections import deque
from transformers import AutoTokenizer, AutoModel

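# --- Configuration ---
# All knobs are plain module-level constants. Notable toggles:
#   fbp           - attach a separate optimizer to every parameter via
#                   post-accumulate-grad hooks ("fused backward") instead of one
#                   global optimizer (see the optimizer section below)
#   unet_gradient - enable gradient checkpointing on the UNet
#   use_decay     - linear warmup followed by cosine LR decay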
ds_path = "/workspace/sdxs/datasets/768"
project = "unet"
batch_size = 36
base_learning_rate = 2.7e-5
min_learning_rate = 1e-5
num_epochs = 80
sample_interval_share = 5
max_length = 192
use_wandb = True
use_comet_ml = False
save_model = True
use_decay = True
fbp = False
optimizer_type = "adam8bit"
torch_compile = False
unet_gradient = True
fixed_seed = False
shuffle = True
comet_ml_api_key = "Agctp26mbqnoYrrlvQuKSTk6r"
comet_ml_workspace = "recoilme"
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True

dtype = torch.float32
save_barrier = 1.01
warmup_percent = 0.01
percentile_clipping = 96
betta2 = 0.999
eps = 1e-7
clip_grad_norm = 1.0
limit = 0
checkpoints_folder = ""
mixed_precision = "no"
gradient_accumulation_steps = 1

accelerator = Accelerator(
    mixed_precision=mixed_precision,
    gradient_accumulation_steps=gradient_accumulation_steps
)
device = accelerator.device

n_diffusion_steps = 40
samples_to_generate = 12
guidance_scale = 4

generated_folder = "samples"
os.makedirs(generated_folder, exist_ok=True)

current_date = datetime.now()
seed = int(current_date.strftime("%Y%m%d"))
if fixed_seed:
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

lora_name = ""
lora_rank = 32
lora_alpha = 64

print("init")

loss_ratios = {
    "mse": 1.,
}
median_coeff_steps = 256

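# MedianLossNormalizer keeps a sliding window (median_coeff_steps steps) of the
# absolute value of every active loss term and rescales each term by
# ratio / median(|loss|), so differently-scaled losses contribute to the total
# in the proportions requested in loss_ratios. With only "mse" active it simply
# normalizes the MSE loss to a roughly unit scale.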
class MedianLossNormalizer:
    def __init__(self, desired_ratios: dict, window_steps: int):
        s = sum(desired_ratios.values())
        self.ratios = {k: (v / s) for k, v in desired_ratios.items()}
        self.buffers = {k: deque(maxlen=window_steps) for k in self.ratios.keys()}
        self.window = window_steps

    def update_and_total(self, losses: dict):
        """
        losses: dict name -> tensor (loss values)

        Behavior:
        - buffer ABS(l) only for active (ratio > 0) losses
        - coeff = ratio / median(abs(loss))
        - total = sum(coeff * loss) over the active losses

        CHANGED: abs() is buffered so the median stays positive and does not break the division.
        """
        for k, v in losses.items():
            if k in self.buffers and self.ratios.get(k, 0) > 0:
                self.buffers[k].append(float(v.detach().abs().cpu()))

        meds = {k: (np.median(self.buffers[k]) if len(self.buffers[k]) > 0 else 1.0) for k in self.buffers}
        coeffs = {k: (self.ratios[k] / max(meds[k], 1e-12)) for k in self.ratios}

        total = sum(coeffs[k] * losses[k] for k in coeffs if self.ratios.get(k, 0) > 0)
        return total, coeffs, meds


normalizer = MedianLossNormalizer(loss_ratios, median_coeff_steps)

if accelerator.is_main_process:
    if use_wandb:
        wandb.init(project=project + lora_name, config={
            "batch_size": batch_size,
            "base_learning_rate": base_learning_rate,
            "num_epochs": num_epochs,
            "optimizer_type": optimizer_type,
        })
    if use_comet_ml:
        from comet_ml import Experiment
        comet_experiment = Experiment(
            api_key=comet_ml_api_key,
            project_name=project,
            workspace=comet_ml_workspace
        )
        hyper_params = {
            "batch_size": batch_size,
            "base_learning_rate": base_learning_rate,
            "num_epochs": num_epochs,
        }
        comet_experiment.log_parameters(hyper_params)

torch.backends.cuda.enable_flash_sdp(True)

vae = AutoencoderKL.from_pretrained("vae1x", torch_dtype=dtype).to("cpu").eval()
tokenizer = AutoTokenizer.from_pretrained("tokenizer")
text_model = AutoModel.from_pretrained("text_encoder").to(device).eval()


def encode_texts(texts, max_length=max_length):
    if texts is None:
        # Nothing to encode.
        return None, None

    with torch.no_grad():
        if isinstance(texts, str):
            texts = [texts]

        # Wrap every caption in the tokenizer's chat template.
        for i, prompt_item in enumerate(texts):
            messages = [
                {"role": "user", "content": prompt_item},
            ]
            prompt_item = tokenizer.apply_chat_template(
                messages,
                tokenize=False,
                add_generation_prompt=True,
            )
            texts[i] = prompt_item

        toks = tokenizer(
            texts,
            return_tensors="pt",
            padding="max_length",
            truncation=True,
            max_length=max_length
        ).to(device)

        outs = text_model(**toks, output_hidden_states=True, return_dict=True)

        # Second-to-last hidden layer as the per-token conditioning.
        hidden = outs.hidden_states[-2]

        attention_mask = toks["attention_mask"]

        # Pooled embedding: hidden state of the last non-padding token.
        sequence_lengths = attention_mask.sum(dim=1) - 1
        batch_size = hidden.shape[0]
        pooled = hidden[torch.arange(batch_size, device=hidden.device), sequence_lengths]

        # Prepend the pooled embedding as an extra token, so the output has
        # max_length + 1 positions.
        pooled_expanded = pooled.unsqueeze(1)
        new_encoder_hidden_states = torch.cat([pooled_expanded, hidden], dim=1)
        new_attention_mask = torch.cat([torch.ones((batch_size, 1), device=device), attention_mask], dim=1)

    return new_encoder_hidden_states, new_attention_mask

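# The dataset's precomputed VAE latents are assumed to be stored as
# (z - shift_factor) * scaling_factor, so decoding later inverts that:
# latents / scaling_factor + shift_factor.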
shift_factor = getattr(vae.config, "shift_factor", 0.0)
if shift_factor is None:
    shift_factor = 0.0
scaling_factor = getattr(vae.config, "scaling_factor", 1.0)
if scaling_factor is None:
    scaling_factor = 1.0

from diffusers import FlowMatchEulerDiscreteScheduler

num_train_timesteps = 1000
scheduler = FlowMatchEulerDiscreteScheduler(num_train_timesteps=num_train_timesteps)

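# DistributedResolutionBatchSampler buckets dataset indices by (width, height)
# so that every batch contains a single resolution and the precomputed latents
# can be stacked. Each global batch is split into per-rank slices, and an
# epoch-seeded RNG keeps the shuffle identical across ranks.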
class DistributedResolutionBatchSampler(Sampler):
    def __init__(self, dataset, batch_size, num_replicas, rank, shuffle=True, drop_last=True):
        self.dataset = dataset
        self.batch_size = max(1, batch_size // num_replicas)
        self.num_replicas = num_replicas
        self.rank = rank
        self.shuffle = shuffle
        self.drop_last = drop_last
        self.epoch = 0

        try:
            widths = np.array(dataset["width"])
            heights = np.array(dataset["height"])
        except KeyError:
            widths = np.zeros(len(dataset))
            heights = np.zeros(len(dataset))

        self.size_keys = np.unique(np.stack([widths, heights], axis=1), axis=0)
        self.size_groups = {}
        for w, h in self.size_keys:
            mask = (widths == w) & (heights == h)
            self.size_groups[(w, h)] = np.where(mask)[0]

        self.group_num_batches = {}
        total_batches = 0
        for size, indices in self.size_groups.items():
            num_full_batches = len(indices) // (self.batch_size * self.num_replicas)
            self.group_num_batches[size] = num_full_batches
            total_batches += num_full_batches

        self.num_batches = (total_batches // self.num_replicas) * self.num_replicas

    def __iter__(self):
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        all_batches = []
        rng = np.random.RandomState(self.epoch)

        for size, indices in self.size_groups.items():
            indices = indices.copy()
            if self.shuffle:
                rng.shuffle(indices)
            num_full_batches = self.group_num_batches[size]
            if num_full_batches == 0:
                continue
            valid_indices = indices[:num_full_batches * self.batch_size * self.num_replicas]
            batches = valid_indices.reshape(-1, self.batch_size * self.num_replicas)
            start_idx = self.rank * self.batch_size
            end_idx = start_idx + self.batch_size
            gpu_batches = batches[:, start_idx:end_idx]
            all_batches.extend(gpu_batches)

        if self.shuffle:
            rng.shuffle(all_batches)
        accelerator.wait_for_everyone()
        return iter(all_batches)

    def __len__(self):
        return self.num_batches

    def set_epoch(self, epoch):
        self.epoch = epoch


def get_fixed_samples_by_resolution(dataset, samples_per_group=1):
    """Pick a small fixed set of samples per resolution group for periodic previews."""
    size_groups = defaultdict(list)
    try:
        widths = dataset["width"]
        heights = dataset["height"]
    except KeyError:
        widths = [0] * len(dataset)
        heights = [0] * len(dataset)
    for i, (w, h) in enumerate(zip(widths, heights)):
        size = (w, h)
        size_groups[size].append(i)

    fixed_samples = {}
    for size, indices in size_groups.items():
        n_samples = min(samples_per_group, len(indices))
        if len(size_groups) == 1:
            n_samples = samples_to_generate
        if n_samples == 0:
            continue
        sample_indices = random.sample(indices, n_samples)
        samples_data = [dataset[idx] for idx in sample_indices]

        latents = torch.tensor(np.array([item["vae"] for item in samples_data])).to(device=device, dtype=dtype)
        texts = [item["text"] for item in samples_data]

        embeddings, masks = encode_texts(texts)

        fixed_samples[size] = (latents, embeddings, masks, texts)

    print(f"Created {len(fixed_samples)} fixed sample groups, one per resolution")
    return fixed_samples

if limit > 0:
    dataset = load_from_disk(ds_path).select(range(limit))
else:
    dataset = load_from_disk(ds_path)

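# collate_fn_simple stacks the precomputed VAE latents and cleans up captions:
# captions starting with "zero" and a random 5% are replaced by the empty
# string (unconditional examples for classifier-free guidance), a leading "."
# is stripped, and boilerplate prefixes such as "The image shows " are removed.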
def collate_fn_simple(batch):
    latents = torch.tensor(np.array([item["vae"] for item in batch])).to(device, dtype=dtype)

    raw_texts = [item["text"] for item in batch]
    texts = [
        "" if t.lower().startswith("zero")
        else "" if random.random() < 0.05
        else t[1:].lstrip() if t.startswith(".")
        else t.replace("The image shows ", "").replace("The image is ", "").replace("This image captures ", "").strip()
        for t in raw_texts
    ]

    embeddings, attention_mask = encode_texts(texts)
    attention_mask = attention_mask.to(dtype=torch.int64)

    return latents, embeddings, attention_mask


batch_sampler = DistributedResolutionBatchSampler(
    dataset=dataset,
    batch_size=batch_size,
    num_replicas=accelerator.num_processes,
    rank=accelerator.process_index,
    shuffle=shuffle
)

dataloader = DataLoader(dataset, batch_sampler=batch_sampler, collate_fn=collate_fn_simple)
print("Batches per epoch:", len(dataloader))
dataloader = accelerator.prepare(dataloader)

start_epoch = 0
global_step = 0
total_training_steps = (len(dataloader) * num_epochs)
world_size = accelerator.state.num_processes

latest_checkpoint = os.path.join(checkpoints_folder, project)
if os.path.isdir(latest_checkpoint):
    print("Loading UNet from checkpoint:", latest_checkpoint)
    unet = UNet2DConditionModel.from_pretrained(latest_checkpoint).to(device=device, dtype=dtype)
    if unet_gradient:
        unet.enable_gradient_checkpointing()
    unet.set_use_memory_efficient_attention_xformers(False)
    try:
        unet.set_attn_processor(AttnProcessor2_0())
    except Exception as e:
        print(f"Failed to enable SDPA: {e}")
        unet.set_use_memory_efficient_attention_xformers(True)
else:
    raise FileNotFoundError(f"UNet checkpoint not found at {latest_checkpoint}")

if lora_name:
    # LoRA injection is not set up in this script; lora_name stays empty.
    pass

if lora_name:
    trainable_params = [p for p in unet.parameters() if p.requires_grad]
else:
    if fbp:
        trainable_params = list(unet.parameters())

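# create_optimizer returns one of three optimizers: bitsandbytes AdamW8bit
# (with percentile-based gradient clipping), plain torch AdamW, or Muon with an
# auxiliary Adam group wrapped in SnooC. The muon/snooc imports are third-party
# dependencies that are only needed when optimizer_type == "muon".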
def create_optimizer(name, params):
    if name == "adam8bit":
        return bnb.optim.AdamW8bit(
            params, lr=base_learning_rate, betas=(0.9, betta2), eps=eps, weight_decay=0.01,
            percentile_clipping=percentile_clipping
        )
    elif name == "adam":
        return torch.optim.AdamW(
            params, lr=base_learning_rate, betas=(0.9, betta2), eps=1e-8, weight_decay=0.01
        )
    elif name == "muon":
        from muon import MuonWithAuxAdam
        trainable_params = [p for p in params if p.requires_grad]
        hidden_weights = [p for p in trainable_params if p.ndim >= 2]
        hidden_gains_biases = [p for p in trainable_params if p.ndim < 2]

        param_groups = [
            dict(params=hidden_weights, use_muon=True,
                 lr=1e-3, weight_decay=1e-4),
            dict(params=hidden_gains_biases, use_muon=False,
                 lr=1e-4, betas=(0.9, 0.95), weight_decay=1e-4),
        ]
        optimizer = MuonWithAuxAdam(param_groups)
        from snooc import SnooC
        return SnooC(optimizer)
    else:
        raise ValueError(f"Unknown optimizer: {name}")


if fbp:
    # One optimizer per parameter, stepped from a post-accumulate-grad hook
    # ("fused backward"): gradients are consumed immediately and never kept
    # around for a global optimizer.step().
    optimizer_dict = {p: create_optimizer(optimizer_type, [p]) for p in trainable_params}

    def optimizer_hook(param):
        optimizer_dict[param].step()
        optimizer_dict[param].zero_grad(set_to_none=True)

    for param in trainable_params:
        param.register_post_accumulate_grad_hook(optimizer_hook)
    unet, optimizer = accelerator.prepare(unet, optimizer_dict)
else:
    optimizer = create_optimizer(optimizer_type, unet.parameters())

    def lr_schedule(step):
        # Linear warmup over the first warmup_percent of training, then cosine
        # decay from base_learning_rate down to min_learning_rate.
        x = step / (total_training_steps * world_size)
        warmup = warmup_percent
        if not use_decay:
            return base_learning_rate
        if x < warmup:
            return min_learning_rate + (base_learning_rate - min_learning_rate) * (x / warmup)
        decay_ratio = (x - warmup) / (1 - warmup)
        return min_learning_rate + 0.5 * (base_learning_rate - min_learning_rate) * \
            (1 + math.cos(math.pi * decay_ratio))

    # LambdaLR expects a multiplier on the optimizer's base lr.
    lr_scheduler = LambdaLR(optimizer, lambda step: lr_schedule(step) / base_learning_rate)
    unet, optimizer, lr_scheduler = accelerator.prepare(unet, optimizer, lr_scheduler)

if torch_compile:
    print("compiling")
    unet = torch.compile(unet)
    print("compiling - ok")

fixed_samples = get_fixed_samples_by_resolution(dataset)

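# get_negative_embedding builds the unconditional ("negative prompt")
# embeddings used for classifier-free guidance at sampling time.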
def get_negative_embedding(neg_prompt="", batch_size=1):
    if not neg_prompt:
        # Note: these sizes are hard-coded for the current text encoder; the
        # sequence length is max_length + 1 because encode_texts prepends a
        # pooled token.
        hidden_dim = 2048
        seq_len = max_length + 1
        empty_emb = torch.zeros((batch_size, seq_len, hidden_dim), dtype=dtype, device=device)
        empty_mask = torch.ones((batch_size, seq_len), dtype=torch.int64, device=device)
        return empty_emb, empty_mask

    uncond_emb, uncond_mask = encode_texts([neg_prompt])
    uncond_emb = uncond_emb.to(dtype=dtype, device=device).repeat(batch_size, 1, 1)
    uncond_mask = uncond_mask.to(device=device).repeat(batch_size, 1)

    return uncond_emb, uncond_mask


uncond_emb, uncond_mask = get_negative_embedding("low quality")

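# generate_and_save_samples runs Euler flow-matching sampling with
# classifier-free guidance on the fixed preview prompts, decodes the result
# with the VAE, saves JPEGs to generated_folder, and logs them to wandb/comet
# when enabled. The VAE is moved to the GPU only for the duration of the call.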
@torch.compiler.disable()
@torch.no_grad()
def generate_and_save_samples(fixed_samples_cpu, uncond_data, step):
    uncond_emb, uncond_mask = uncond_data

    original_model = None
    try:
        if not torch_compile:
            original_model = accelerator.unwrap_model(unet, keep_torch_compile=True).eval()
        else:
            original_model = unet.eval()

        vae.to(device=device).eval()

        all_generated_images = []
        all_captions = []

        for size, (sample_latents, sample_text_embeddings, sample_mask, sample_text) in fixed_samples_cpu.items():
            width, height = size
            sample_latents = sample_latents.to(dtype=dtype, device=device)
            sample_text_embeddings = sample_text_embeddings.to(dtype=dtype, device=device)
            sample_mask = sample_mask.to(device=device)

            # Start from pure noise with a fixed seed so previews are comparable across steps.
            latents = torch.randn(
                sample_latents.shape,
                device=device,
                dtype=sample_latents.dtype,
                generator=torch.Generator(device=device).manual_seed(seed)
            )

            scheduler.set_timesteps(n_diffusion_steps, device=device)

            for t in scheduler.timesteps:
                if guidance_scale != 1:
                    # Duplicate the latents: first half unconditional, second half conditional.
                    latent_model_input = torch.cat([latents, latents], dim=0)

                    curr_batch_size = sample_text_embeddings.shape[0]
                    seq_len = sample_text_embeddings.shape[1]
                    hidden_dim = sample_text_embeddings.shape[2]

                    neg_emb_batch = uncond_emb[0:1].expand(curr_batch_size, -1, -1)
                    text_embeddings_batch = torch.cat([neg_emb_batch, sample_text_embeddings], dim=0)

                    neg_mask_batch = uncond_mask[0:1].expand(curr_batch_size, -1)
                    attention_mask_batch = torch.cat([neg_mask_batch, sample_mask], dim=0)
                else:
                    latent_model_input = latents
                    text_embeddings_batch = sample_text_embeddings
                    attention_mask_batch = sample_mask

                model_out = original_model(
                    latent_model_input,
                    t,
                    encoder_hidden_states=text_embeddings_batch,
                    encoder_attention_mask=attention_mask_batch,
                )
                flow = getattr(model_out, "sample", model_out)

                if guidance_scale != 1:
                    flow_uncond, flow_cond = flow.chunk(2)
                    flow = flow_uncond + guidance_scale * (flow_cond - flow_uncond)

                latents = scheduler.step(flow, t, latents).prev_sample

            current_latents = latents

            # Undo the latent normalization before decoding.
            latent_for_vae = current_latents.detach() / scaling_factor + shift_factor
            decoded = vae.decode(latent_for_vae.to(torch.float32)).sample
            decoded_fp32 = decoded.to(torch.float32)

            for img_idx, img_tensor in enumerate(decoded_fp32):
                img = (img_tensor / 2 + 0.5).clamp(0, 1).cpu().numpy()
                img = img.transpose(1, 2, 0)

                if np.isnan(img).any():
                    print("NaNs found, saving stopped! Step:", step)
                pil_img = Image.fromarray((img * 255).astype("uint8"))

                # Pad every preview to a common canvas so they can be shown as one grid.
                max_w_overall = max(s[0] for s in fixed_samples_cpu.keys())
                max_h_overall = max(s[1] for s in fixed_samples_cpu.keys())
                max_w_overall = max(255, max_w_overall)
                max_h_overall = max(255, max_h_overall)

                padded_img = ImageOps.pad(pil_img, (max_w_overall, max_h_overall), color='white')
                all_generated_images.append(padded_img)

                caption_text = sample_text[img_idx][:300] if img_idx < len(sample_text) else ""
                all_captions.append(caption_text)

                sample_path = f"{generated_folder}/{project}_{width}x{height}_{img_idx}.jpg"
                pil_img.save(sample_path, "JPEG", quality=96)

        if use_wandb and accelerator.is_main_process:
            wandb_images = [
                wandb.Image(img, caption=f"{all_captions[i]}")
                for i, img in enumerate(all_generated_images)
            ]
            wandb.log({"generated_images": wandb_images})
        if use_comet_ml and accelerator.is_main_process:
            for i, img in enumerate(all_generated_images):
                comet_experiment.log_image(
                    image_data=img,
                    name=f"step_{step}_img_{i}",
                    step=step,
                    metadata={"caption": all_captions[i]}
                )
    finally:
        vae.to("cpu")
        torch.cuda.empty_cache()
        gc.collect()


if accelerator.is_main_process:
    if save_model:
        print("Generating samples before training starts...")
        generate_and_save_samples(fixed_samples, (uncond_emb, uncond_mask), 0)
accelerator.wait_for_everyone()

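# save_checkpoint writes the (unwrapped) UNet with save_pretrained. When a
# variant such as "fp16" is requested the live model is cast to float16 for
# saving, so it is cast back to the training dtype afterwards.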
def save_checkpoint(unet, variant=""): |
|
|
if accelerator.is_main_process: |
|
|
if lora_name: |
|
|
save_lora_checkpoint(unet) |
|
|
else: |
|
|
model_to_save = None |
|
|
if not torch_compile: |
|
|
model_to_save = accelerator.unwrap_model(unet) |
|
|
else: |
|
|
model_to_save = unet |
|
|
|
|
|
if variant != "": |
|
|
model_to_save.to(dtype=torch.float16).save_pretrained( |
|
|
os.path.join(checkpoints_folder, f"{project}"), variant=variant |
|
|
) |
|
|
else: |
|
|
model_to_save.save_pretrained(os.path.join(checkpoints_folder, f"{project}")) |
|
|
|
|
|
unet = unet.to(dtype=dtype) |
|
|
|
|
|
|
|
|
if accelerator.is_main_process: |
|
|
print(f"Total steps per GPU: {total_training_steps}") |
|
|
|
|
|
epoch_loss_points = [] |
|
|
progress_bar = tqdm(total=total_training_steps, disable=not accelerator.is_local_main_process, desc="Training", unit="step") |
|
|
|
|
|
steps_per_epoch = len(dataloader) |
|
|
sample_interval = max(1, steps_per_epoch // sample_interval_share) |
|
|
min_loss = 2. |
|
|
|
|
|
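# Training objective (flow matching / rectified flow): sample t ~ U(0, 1),
# build x_t = (1 - t) * x0 + t * noise, and train the UNet to predict the
# velocity v = noise - x0 with a plain MSE loss. The integer timestep passed
# to the UNet is just t scaled to the scheduler's 1000-step range.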
for epoch in range(start_epoch, start_epoch + num_epochs):
    batch_losses = []
    batch_grads = []
    batch_sampler.set_epoch(epoch)
    accelerator.wait_for_everyone()
    unet.train()

    for step, (latents, embeddings, attention_mask) in enumerate(dataloader):
        with accelerator.accumulate(unet):
            if not save_model and step == 5:
                used_gb = torch.cuda.max_memory_allocated() / 1024**3
                print(f"Step {step}: {used_gb:.2f} GB")

            noise = torch.randn_like(latents, dtype=latents.dtype)

            t = torch.rand(latents.shape[0], device=latents.device, dtype=latents.dtype)

            noisy_latents = (1.0 - t.view(-1, 1, 1, 1)) * latents + t.view(-1, 1, 1, 1) * noise

            timesteps = (t * scheduler.config.num_train_timesteps).long()

            model_pred = unet(
                noisy_latents,
                timesteps,
                encoder_hidden_states=embeddings,
                encoder_attention_mask=attention_mask
            ).sample

            # Velocity target for flow matching.
            target = noise - latents
            mse_loss = F.mse_loss(model_pred.float(), target.float())
            batch_losses.append(mse_loss.detach().item())

            if (global_step % 100 == 0) or (global_step % sample_interval == 0):
                accelerator.wait_for_everyone()

            losses_dict = {}
            losses_dict["mse"] = mse_loss

            abs_for_norm = {k: losses_dict.get(k, torch.tensor(0.0, device=device)) for k in normalizer.ratios.keys()}
            total_loss, coeffs, meds = normalizer.update_and_total(abs_for_norm)

            if (global_step % 100 == 0) or (global_step % sample_interval == 0):
                accelerator.wait_for_everyone()

            accelerator.backward(total_loss)

            if (global_step % 100 == 0) or (global_step % sample_interval == 0):
                accelerator.wait_for_everyone()

            grad = 0.0
            if not fbp:
                if accelerator.sync_gradients:
                    grad_val = accelerator.clip_grad_norm_(unet.parameters(), clip_grad_norm)
                    grad = float(grad_val)
                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad(set_to_none=True)

            if accelerator.sync_gradients:
                global_step += 1
                progress_bar.update(1)

            if accelerator.is_main_process:
                if fbp:
                    current_lr = base_learning_rate
                else:
                    current_lr = lr_scheduler.get_last_lr()[0]
                batch_grads.append(grad)

                log_data = {}
                log_data["loss"] = mse_loss.detach().item()
                log_data["lr"] = current_lr
                log_data["grad"] = grad
                log_data["loss_total"] = float(total_loss.item())
                for k, c in coeffs.items():
                    log_data[f"coeff_{k}"] = float(c)
                if accelerator.sync_gradients:
                    if use_wandb:
                        wandb.log(log_data, step=global_step)
                    if use_comet_ml:
                        comet_experiment.log_metrics(log_data, step=global_step)

                if global_step % sample_interval == 0:
                    generate_and_save_samples(fixed_samples, (uncond_emb, uncond_mask), global_step)
                    last_n = sample_interval

                    if save_model:
                        has_losses = len(batch_losses) > 0
                        avg_sample_loss = np.mean(batch_losses[-sample_interval:]) if has_losses else 0.0
                        last_loss = batch_losses[-1] if has_losses else 0.0
                        max_loss = max(avg_sample_loss, last_loss)
                        # Only keep checkpoints while the loss stays near its best value.
                        should_save = max_loss < min_loss * save_barrier
                        print(
                            f"Saving: {should_save} | Max: {max_loss:.4f} | "
                            f"Last: {last_loss:.4f} | Avg: {avg_sample_loss:.4f}"
                        )

                        if should_save:
                            min_loss = max_loss
                            save_checkpoint(unet)

    if accelerator.is_main_process:
        avg_epoch_loss = np.mean(batch_losses) if len(batch_losses) > 0 else 0.0
        avg_epoch_grad = np.mean(batch_grads) if len(batch_grads) > 0 else 0.0

        print(f"\nEpoch {epoch} finished. Average loss: {avg_epoch_loss:.6f}")
        log_data_ep = {
            "epoch_loss": avg_epoch_loss,
            "epoch_grad": avg_epoch_grad,
            "epoch": epoch + 1,
        }
        if use_wandb:
            wandb.log(log_data_ep)
        if use_comet_ml:
            comet_experiment.log_metrics(log_data_ep)

if accelerator.is_main_process:
    print("Training finished! Saving the final model...")
    if save_model:
        save_checkpoint(unet, "fp16")
    if use_comet_ml:
        comet_experiment.end()
accelerator.free_memory()
if torch.distributed.is_initialized():
    torch.distributed.destroy_process_group()

print("Done!")