|
import os
import math
import gc
import random

import numpy as np
import torch
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
from torch.optim.lr_scheduler import LambdaLR
from torch.distributed import broadcast_object_list

from diffusers import UNet2DConditionModel, AutoencoderKL, DDPMScheduler
from accelerate import Accelerator
from accelerate.state import DistributedType
from datasets import load_from_disk
from tqdm import tqdm
from PIL import Image
import wandb

# --- Training configuration ---
save_path = "datasets/siski384"        # pre-encoded dataset (VAE latents + text embeddings)
batch_size = 5
base_learning_rate = 8e-5
min_learning_rate = 2e-5
num_epochs = 36
project = "sdxs"
use_wandb = True
limit = 0                              # 0 = use the full dataset, >0 = truncate for quick runs
checkpoint_file = "full_checkpoint.pt"
reset_lroptim = True
saveckpt = False
save_model = True

# Preview sampling settings
n_diffusion_steps = 50
samples_to_generate = 6

generated_folder = "samples"
checkpoints_folder = ""
os.makedirs(generated_folder, exist_ok=True)

# Reproducibility
seed = 42
torch.manual_seed(seed)
np.random.seed(seed)
random.seed(seed)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(seed)

print("init") |
|
|
|
dtype = torch.bfloat16 |
|
accelerator = Accelerator(mixed_precision="bf16") |
|
device = accelerator.device |
|
gen = torch.Generator(device=device) |
|
gen.manual_seed(42) |
|
|
|
|
|
if use_wandb and accelerator.is_main_process:
    wandb.init(project=project, config={
        "batch_size": batch_size,
        "base_learning_rate": base_learning_rate,
        "num_epochs": num_epochs,
        "n_diffusion_steps": n_diffusion_steps,
        "samples_to_generate": samples_to_generate,
        "dtype": str(dtype),
    })

dataset = load_from_disk(save_path)
if limit > 0:
    dataset = dataset.select(range(limit))

def collate_fn(batch):
    # The dataset stores pre-computed VAE latents ("vae") and text embeddings ("embeddings"),
    # so a batch is just the stacked tensors moved to the training device.
    latents = torch.tensor(np.array([item["vae"] for item in batch]), dtype=dtype).to(device)
    embeddings = torch.tensor(np.array([item["embeddings"] for item in batch]), dtype=dtype).to(device)
    return latents, embeddings

dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)
print("Total batches per epoch:", len(dataloader))
dataloader = accelerator.prepare(dataloader)

# The 16-channel VAE stays on the CPU and is moved to the GPU only while decoding preview samples.
vae = AutoencoderKL.from_pretrained("AuraDiffusion/16ch-vae").to("cpu", dtype=dtype)

scheduler = DDPMScheduler(
    num_train_timesteps=1000,
    prediction_type="v_prediction",
    rescale_betas_zero_snr=True,
    timestep_spacing="leading",
    steps_offset=1,
)
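# With prediction_type="v_prediction" the UNet is trained to predict the velocity
# v_t = sqrt(alpha_bar_t) * noise - sqrt(1 - alpha_bar_t) * x_0, which is exactly
# what scheduler.get_velocity() returns as the regression target in the loop below.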
|
|
|
|
|
start_epoch = 0
global_step = 0

total_training_steps = len(dataloader) * num_epochs

world_size = accelerator.state.num_processes
print(f"World Size: {world_size}")

latest_checkpoint = os.path.join(checkpoints_folder, project)
if os.path.isdir(latest_checkpoint):
    print("Loading UNet from checkpoint:", latest_checkpoint)
    unet = UNet2DConditionModel.from_pretrained(latest_checkpoint).to(device, dtype=dtype)
else:
    # Without this directory `unet` would be undefined below, so fail early with a clear error.
    raise FileNotFoundError(f"UNet checkpoint directory not found: {latest_checkpoint}")

optimizer = torch.optim.AdamW(
    unet.parameters(),
    lr=base_learning_rate,
    betas=(0.9, 0.999),
    weight_decay=1e-6,
    eps=1e-8,
)

def lr_schedule(step, max_steps, base_lr, min_lr, use_decay=True):
    """Linear warmup from min_lr to base_lr over the first 10% of steps,
    then cosine decay back down to min_lr."""
    if not use_decay:
        return base_lr

    x = step / max_steps
    if x < 0.1:
        # warmup phase
        return min_lr + (base_lr - min_lr) * (x / 0.1)
    else:
        # cosine decay phase
        decay_ratio = (x - 0.1) / (1 - 0.1)
        return min_lr + 0.5 * (base_lr - min_lr) * (1 + math.cos(math.pi * decay_ratio))
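# Example: with base_lr=8e-5 and min_lr=2e-5, halfway through warmup (x=0.05) the lr is
# 2e-5 + 6e-5 * 0.5 = 5e-5; at the end of training (x=1.0) it has decayed back to 2e-5.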
|
|
|
|
|
use_lr_decay = True


def custom_lr_lambda(step):
    return lr_schedule(step, total_training_steps * world_size,
                       base_learning_rate, min_learning_rate,
                       use_lr_decay) / base_learning_rate


lr_scheduler = LambdaLR(optimizer, lr_lambda=custom_lr_lambda)
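# LambdaLR expects a multiplier relative to the optimizer's base lr, hence the division
# by base_learning_rate. max_steps is scaled by world_size to compensate for Accelerate
# advancing a prepared scheduler once per process on every optimizer step.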
|
|
|
|
|
unet, optimizer, lr_scheduler = accelerator.prepare(unet, optimizer, lr_scheduler)

checkpoint_path = os.path.join(checkpoints_folder, checkpoint_file)
if os.path.isfile(checkpoint_path) and accelerator.is_main_process:
    print(f"Loading full checkpoint: {checkpoint_path}")
    checkpoint = torch.load(checkpoint_path, map_location="cpu")
else:
    checkpoint = None

# In multi-GPU runs only rank 0 reads the file, so broadcast the loaded object to the other ranks.
if accelerator.state.distributed_type != DistributedType.NO:
    checkpoint_list = [checkpoint]
    broadcast_object_list(checkpoint_list, src=0)
    checkpoint = checkpoint_list[0]

if checkpoint is not None:
    accelerator.unwrap_model(unet).load_state_dict(checkpoint['model_state_dict'])
    start_epoch = checkpoint['epoch'] + 1
    global_step = checkpoint['global_step']
    print(f"Resuming from epoch {start_epoch}, step {global_step}")

# Draw a fixed set of items once so preview generations are comparable between logging steps.
sample_indices = np.random.choice(len(dataset), samples_to_generate, replace=False)
sample_data = [dataset[int(i)] for i in sample_indices]


def get_sample_inputs(sample_data):
    return collate_fn(sample_data)

@torch.no_grad()
def generate_and_save_samples(step: int):
    """
    Moves the VAE to the device, generates samples with the configured number of
    diffusion steps, decodes them and logs them to WandB.
    After generation the VAE is returned to the CPU.
    """
    original_model = None
    try:
        original_model = accelerator.unwrap_model(unet)
        vae.to(accelerator.device, dtype=dtype)

        scheduler.set_timesteps(n_diffusion_steps)

        # The dataset latents are only used for their shape; generation starts from pure noise.
        sample_latents, sample_text_embeddings = get_sample_inputs(sample_data)
        sample_latents = sample_latents.to(dtype)
        sample_latents = torch.randn(
            sample_latents.shape,
            generator=gen,
            device=sample_latents.device,
            dtype=sample_latents.dtype
        )

        guidance_scale = 3.0  # currently unused: there is no classifier-free guidance branch below

        # Reverse diffusion loop
        for t in scheduler.timesteps:
            latent_model_input = scheduler.scale_model_input(sample_latents, t)
            noise_pred = original_model(latent_model_input, t, sample_text_embeddings).sample
            sample_latents = scheduler.step(noise_pred, t, sample_latents).prev_sample

        # Undo the VAE scaling/shift and decode to images
        latent = (sample_latents.detach() / vae.config.scaling_factor) + vae.config.shift_factor
        latent = latent.to(accelerator.device, dtype=dtype)
        decoded = vae.decode(latent).sample

        generated_images = []
        for img_idx, img_tensor in enumerate(decoded):
            img = (img_tensor.to(torch.float32) / 2 + 0.5).clamp(0, 1).cpu().numpy().transpose(1, 2, 0)
            pil_img = Image.fromarray((img * 255).astype("uint8"))
            generated_images.append(pil_img)
            sample_path = f"{generated_folder}/{project}_{img_idx}.jpg"
            pil_img.save(sample_path, "JPEG", quality=95)

        if use_wandb and accelerator.is_main_process:
            wandb_images = [wandb.Image(img, caption=f"Sample {i}") for i, img in enumerate(generated_images)]
            wandb.log({"generated_images": wandb_images, "global_step": step})

    finally:
        # Free GPU memory: the VAE goes back to the CPU regardless of success or failure.
        vae.to("cpu")
        if original_model is not None:
            del original_model
        torch.cuda.empty_cache()
        gc.collect()

if accelerator.is_main_process:
    print("Generating samples before training starts...")
    generate_and_save_samples(step=0)

if accelerator.is_main_process:
    print(f"Total steps per GPU: {total_training_steps}")

progress_bar = tqdm(total=total_training_steps, disable=not accelerator.is_local_main_process, desc="Training", unit="step")

steps_per_epoch = len(dataloader)
sample_interval = max(1, steps_per_epoch // 4)  # log preview samples roughly four times per epoch

def save_ckpt():
    if saveckpt:
        checkpoint_path = os.path.join(checkpoints_folder, checkpoint_file)
        torch.save({
            'model_state_dict': accelerator.unwrap_model(unet).state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'epoch': epoch,
            'global_step': global_step,
        }, checkpoint_path)
        print(f"Full checkpoint saved: {checkpoint_path}")
|
|
|
|
|
for epoch in range(start_epoch, start_epoch + num_epochs):
    batch_losses = []
    unet.train()

    for step, (latents, embeddings) in enumerate(dataloader):
        with accelerator.accumulate(unet):
            # Sample noise and a random timestep for every latent in the batch
            noise = torch.randn_like(latents)
            timesteps = torch.randint(
                1,
                scheduler.config.num_train_timesteps,
                (latents.shape[0],),
                device=device
            ).long()

            # Forward diffusion: add noise to the clean latents
            noisy_latents = scheduler.add_noise(latents, noise, timesteps)

            # Predict and regress against the velocity target (v-prediction)
            noise_pred = unet(noisy_latents, timesteps, embeddings).sample
            target = scheduler.get_velocity(latents, noise, timesteps)

            loss = torch.nn.functional.mse_loss(
                noise_pred.float(),
                target.float()
            )

            accelerator.backward(loss)
            grad_norm = accelerator.clip_grad_norm_(unet.parameters(), 1.0)

            optimizer.step()
            lr_scheduler.step()
            optimizer.zero_grad(set_to_none=True)

        global_step += 1
        progress_bar.update(1)

        if accelerator.is_main_process:
            current_lr = lr_scheduler.get_last_lr()[0]
            batch_losses.append(loss.detach().item())

            if global_step % 2 == 0 and use_wandb:
                wandb.log({
                    "loss": loss.detach().item(),
                    "learning_rate": current_lr,
                    "grad_norm": grad_norm.item(),
                    "epoch": epoch,
                    "global_step": global_step
                })

        if global_step % sample_interval == 0 and accelerator.is_main_process:
            generate_and_save_samples(global_step)

            avg_loss = np.mean(batch_losses[-sample_interval:])
            if use_wandb:
                wandb.log({"intermediate_loss": avg_loss})

    accelerator.wait_for_everyone()

    if accelerator.is_main_process:
        save_ckpt()

        if save_model:
            accelerator.unwrap_model(unet).save_pretrained(os.path.join(checkpoints_folder, f"{project}"))

        avg_epoch_loss = np.mean(batch_losses)
        print(f"\nEpoch {epoch} finished. Average loss: {avg_epoch_loss:.6f}")
        if use_wandb:
            wandb.log({"epoch_loss": avg_epoch_loss, "epoch": epoch + 1})

if accelerator.is_main_process:
    print("Training complete! Saving the final model...")

    if save_model:
        accelerator.unwrap_model(unet).save_pretrained(os.path.join(checkpoints_folder, f"{project}"))
    print("Done!")