Diffusers
Safetensors
sdxs / train_siski.py
recoilme's picture
ep8
7807721
import os
import math
import torch
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
from torch.optim.lr_scheduler import LambdaLR
from diffusers import UNet2DConditionModel, AutoencoderKL, DDPMScheduler
from accelerate import Accelerator
from datasets import load_from_disk
from tqdm import tqdm
from PIL import Image
import wandb
import random
import gc
from accelerate.state import DistributedType
from torch.distributed import broadcast_object_list
#from adamw_bfloat16 import LR, AdamW_BF16
# --------------------------- Параметры ---------------------------
save_path = "datasets/siski384"#"datasets/siski64" #"datasets/mnist"
batch_size = 5
base_learning_rate = 8e-5
min_learning_rate = 2e-5
num_epochs = 36 #18
project = "sdxs"
use_wandb = True
limit = 0
checkpoint_file = "full_checkpoint.pt"
reset_lroptim = True
saveckpt = False
save_model = True
# Параметры для диффузии
n_diffusion_steps = 50 # Увеличиваем число шагов для лучшего качества
samples_to_generate = 6
# Папки для сохранения результатов
generated_folder = "samples"
checkpoints_folder = ""
os.makedirs(generated_folder, exist_ok=True)
# Настройка seed для воспроизводимости
seed = 42
torch.manual_seed(seed)
np.random.seed(seed)
random.seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed_all(seed)
print("init")
# --------------------------- Инициализация Accelerator ---------------------------
dtype = torch.bfloat16
accelerator = Accelerator(mixed_precision="bf16")
device = accelerator.device
gen = torch.Generator(device=device)
gen.manual_seed(42)
# --------------------------- Инициализация WandB ---------------------------
if use_wandb and accelerator.is_main_process:
wandb.init(project=project, config={
"batch_size": batch_size,
"base_learning_rate": base_learning_rate,
"num_epochs": num_epochs,
"n_diffusion_steps": n_diffusion_steps,
"samples_to_generate": samples_to_generate,
"dtype": str(dtype)
})
# --------------------------- Загрузка датасета ---------------------------
if limit > 0:
dataset = load_from_disk(save_path).select(range(limit))
else:
dataset = load_from_disk(save_path)
def collate_fn(batch):
# Преобразуем список в тензоры и перемещаем на девайс
latents = torch.tensor(np.array([item["vae"] for item in batch]), dtype=dtype).to(device)
embeddings = torch.tensor(np.array([item["embeddings"] for item in batch]), dtype=dtype).to(device)
return latents, embeddings
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)
print("Total samples",len(dataloader))
dataloader = accelerator.prepare(dataloader)
# --------------------------- Загрузка моделей ---------------------------
# VAE загружается на CPU для экономии GPU-памяти
vae = AutoencoderKL.from_pretrained("AuraDiffusion/16ch-vae").to("cpu", dtype=dtype)
# DDPMScheduler с V_Prediction и Zero-SNR
scheduler = DDPMScheduler(
num_train_timesteps=1000, # Полный график шагов для обучения
prediction_type="v_prediction", # V-Prediction
rescale_betas_zero_snr=True, # Включение Zero-SNR
timestep_spacing="leading", # Добавляем улучшенное распределение шагов
steps_offset=1 # Избегаем проблем с нулевым timestep
)
# Инициализация переменных для возобновления обучения
start_epoch = 0
global_step = 0
# Расчёт общего количества шагов
total_training_steps = (len(dataloader) * num_epochs)
# Get the world size
world_size = accelerator.state.num_processes
print(f"World Size: {world_size}")
# Опция загрузки модели из последнего чекпоинта (если существует)
latest_checkpoint = os.path.join(checkpoints_folder, project)
if os.path.isdir(latest_checkpoint):
print("Загружаем UNet из чекпоинта:", latest_checkpoint)
unet = UNet2DConditionModel.from_pretrained(latest_checkpoint).to(device, dtype=dtype)
# --------------------------- Оптимизатор и кастомный LR scheduler ---------------------------
# Улучшенный AdamW с правильными параметрами
from optimi import AdamW
optimizer = torch.optim.AdamW(
unet.parameters(),
lr=base_learning_rate,
betas=(0.9, 0.999),
weight_decay=1e-6, # Безопасное значение для SD
eps=1e-8
)
#optimizer = AdamW_BF16(unet.parameters(), lr_function=LR(lr=base_learning_rate, preheat_steps=5000, decay_power=-0.25))
# Улучшенный планировщик с фиксированной скоростью обучения для случаев возобновления
# и возможностью плавного затухания
def lr_schedule(step, max_steps, base_lr, min_lr, use_decay=True):
# Если не используем затухание, возвращаем базовый LR
if not use_decay:
return base_lr
# Иначе используем линейный прогрев и косинусное затухание
x = step / max_steps
if x < 0.1:
# Линейный прогрев до 10% шагов
return min_lr + (base_lr - min_lr) * (x / 0.1)
else:
# Косинусное затухание
decay_ratio = (x - 0.1) / (1 - 0.1)
return min_lr + 0.5 * (base_lr - min_lr) * (1 + math.cos(math.pi * decay_ratio))
# Флаг для контроля использования затухания LR, вы можете менять его при возобновлении обучения
use_lr_decay = True # Устанавливайте в False, если хотите отключить затухание
def custom_lr_lambda(step):
return lr_schedule(step, total_training_steps*world_size,
base_learning_rate, min_learning_rate,
use_lr_decay) / base_learning_rate
lr_scheduler = LambdaLR(optimizer, lr_lambda=custom_lr_lambda)
# Подготовка через Accelerator
unet, optimizer,lr_scheduler = accelerator.prepare(unet, optimizer,lr_scheduler)
# Загрузка полного чекпоинта если существует
checkpoint_path = os.path.join(checkpoints_folder, checkpoint_file)
if os.path.isfile(checkpoint_path) and accelerator.is_main_process:
print(f"Загружаем полный чекпоинт: {checkpoint_path}")
checkpoint = torch.load(checkpoint_path, map_location="cpu")
else:
checkpoint = None
# Если задействовано распределённое обучение, рассылаем объект чекпоинта всем процессам.
if accelerator.state.distributed_type != DistributedType.NO:
checkpoint_list = [checkpoint]
broadcast_object_list(checkpoint_list, src=0)
checkpoint = checkpoint_list[0]
if checkpoint is not None:
# Загружаем состояния
accelerator.unwrap_model(unet).load_state_dict(checkpoint['model_state_dict'])
#if not reset_lroptim:
# Пересоздаем оптимизатор с текущим learning rate
# optimizer.param_groups[0]['lr'] = base_learning_rate
# optimizer = torch.optim.AdamW(
# unet.parameters(),
# lr=base_learning_rate, # Можно изменить при возобновлении
# betas=(0.9, 0.999),
# weight_decay=1e-6,
# eps=1e-8
# )
# print(f"optimizer learning rate: {optimizer.param_groups[0]['lr']}")
# optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
#lr_scheduler.load_state_dict(checkpoint['lr_scheduler_state_dict'])
start_epoch = checkpoint['epoch'] + 1
global_step = checkpoint['global_step']
print(f"Продолжаем с эпохи {start_epoch}, шага {global_step}")
# --------------------------- Фиксированные семплы для генерации ---------------------------
# Выбираем фиксированные индексы (приводим к int)
sample_indices = np.random.choice(len(dataset), samples_to_generate, replace=False)
sample_data = [dataset[int(i)] for i in sample_indices]
def get_sample_inputs(sample_data):
# Получаем latents и текстовые эмбеддинги для выбранных примеров (collate_fn переносит на device)
return collate_fn(sample_data)
@torch.no_grad()
def generate_and_save_samples(step: int):
"""
Перемещает VAE на device, генерирует семплы с заданным числом diffusion шагов,
декодирует их и отправляет на WandB.
После генерации VAE возвращается на CPU.
"""
try:
original_model = accelerator.unwrap_model(unet)
# Перемещаем VAE на device для семплирования
vae.to(accelerator.device, dtype=dtype)
# Устанавливаем количество diffusion шагов
scheduler.set_timesteps(n_diffusion_steps)
# Получаем фиксированные данные
sample_latents, sample_text_embeddings = get_sample_inputs(sample_data)
sample_latents = sample_latents.to(dtype)
# Инициализируем латенты случайным шумом
sample_latents = torch.randn(
sample_latents.shape,
generator=gen,
device=sample_latents.device,
dtype=sample_latents.dtype
)
# Добавляем CFG для лучшего качества
guidance_scale = 3.0
# Генерация изображений
with torch.no_grad():
for t in scheduler.timesteps:
latent_model_input = scheduler.scale_model_input(sample_latents, t)
noise_pred = original_model(latent_model_input, t, sample_text_embeddings).sample
# Шаг диффузии
sample_latents = scheduler.step(noise_pred, t, sample_latents).prev_sample
# Декодирование через VAE
latent = (sample_latents.detach() / vae.config.scaling_factor) + vae.config.shift_factor
latent = latent.to(accelerator.device, dtype=dtype)
decoded = vae.decode(latent).sample
# Преобразуем тензоры в PIL-изображения
generated_images = []
for img_idx, img_tensor in enumerate(decoded):
img = (img_tensor.to(torch.float32) / 2 + 0.5).clamp(0, 1).cpu().numpy().transpose(1, 2, 0)
pil_img = Image.fromarray((img * 255).astype("uint8"))
generated_images.append(pil_img)
save_path = f"{generated_folder}/{project}_{img_idx}.jpg"
pil_img.save(save_path, "JPEG", quality=95)
# Отправляем изображения на WandB
if use_wandb and accelerator.is_main_process:
wandb_images = [wandb.Image(img, caption=f"Sample {i}") for i, img in enumerate(generated_images)]
wandb.log({"generated_images": wandb_images, "global_step": step})
finally:
# Гарантированное перемещение VAE обратно на CPU
vae.to("cpu")
if original_model is not None:
del original_model
torch.cuda.empty_cache() # Очищаем кэш CUDA
gc.collect()
# --------------------------- Генерация сэмплов перед обучением ---------------------------
if accelerator.is_main_process:
print("Генерация сэмплов до старта обучения...")
generate_and_save_samples(step=0)
# --------------------------- Тренировочный цикл ---------------------------
# Для логирования среднего лосса каждые 10% эпохи
if accelerator.is_main_process:
print(f"Total steps per GPU: {total_training_steps}")
print(f"[GPU {accelerator.process_index}] Total steps: {total_training_steps}")
epoch_loss_points = []
progress_bar = tqdm(total=total_training_steps, disable=not accelerator.is_local_main_process, desc="Training", unit="step")
# Определяем интервал для сэмплирования и логирования в пределах эпохи (10% эпохи)
steps_per_epoch = len(dataloader)
sample_interval = max(1, steps_per_epoch // 4)
def save_ckpt():
if saveckpt:
# После сохранения модели в конце каждой эпохи
checkpoint_path = os.path.join(checkpoints_folder, checkpoint_file)
torch.save({
'model_state_dict': accelerator.unwrap_model(unet).state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
#'lr_scheduler_state_dict': lr_scheduler.state_dict(),
'epoch': epoch,
'global_step': global_step,
}, checkpoint_path)
print(f"Полный чекпоинт сохранён: {checkpoint_path}")
# Начинаем с указанной эпохи (полезно при возобновлении)
for epoch in range(start_epoch, start_epoch + num_epochs):
batch_losses = []
unet.train()
for step, (latents, embeddings) in enumerate(dataloader):
with accelerator.accumulate(unet):
# Forward pass - используем timesteps из более широкого диапазона
noise = torch.randn_like(latents)
timesteps = torch.randint(
1, # Начинаем с 1, не с 0
scheduler.config.num_train_timesteps,
(latents.shape[0],),
device=device
).long()
# Добавляем шум к латентам
noisy_latents = scheduler.add_noise(latents, noise, timesteps)
# Получаем предсказание шума
noise_pred = unet(noisy_latents, timesteps, embeddings).sample#.to(dtype=torch.bfloat16)
# Используем целевое значение v_prediction
target = scheduler.get_velocity(latents, noise, timesteps)
#print("noise_pred",noise_pred.dtype) # Должно быть torch.bfloat16
#print("target",target.dtype) # Должно быть torch.bfloat16
# Считаем лосс в BF16 для стабильности
loss = torch.nn.functional.mse_loss(
noise_pred.float(),
target.float()
)
#print("loss",loss.dtype)
# Делаем backward через Accelerator
accelerator.backward(loss) # так делал раньше
# Используем ограничение нормы градиентов через Accelerator
grad_norm = accelerator.clip_grad_norm_(unet.parameters(), 1.0)
# Важно: шаг оптимизатора должен быть именно здесь
optimizer.step()
lr_scheduler.step()
optimizer.zero_grad(set_to_none=True)
# Увеличиваем счетчик глобальных шагов
global_step += 1
# Обновляем прогресс-бар
progress_bar.update(1)
# Логируем метрики
if accelerator.is_main_process:
current_lr = lr_scheduler.get_last_lr()[0]
batch_losses.append(loss.detach().item())
# Логируем в Wandb
if global_step % 2 == 0 and use_wandb:
wandb.log({
"loss": loss.detach().item(),
"learning_rate": current_lr,
"epoch": epoch,
"grad_norm": grad_norm.item(),
"epoch": epoch,
"global_step": global_step
})
# Генерируем сэмплы с заданным интервалом
if global_step % sample_interval == 0 and accelerator.is_main_process:
generate_and_save_samples(global_step)
# Выводим текущий лосс
avg_loss = np.mean(batch_losses[-sample_interval:])
#print(f"Эпоха {epoch}, шаг {global_step}, средний лосс: {avg_loss:.6f}, LR: {current_lr:.8f}")
if use_wandb:
wandb.log({"intermediate_loss": avg_loss})
# По окончании эпохи
accelerator.wait_for_everyone()
# Сохраняем чекпоинт в конце каждой эпохи
if accelerator.is_main_process:
save_ckpt()
# Сохраняем UNet отдельно для удобства использования
if save_model:
accelerator.unwrap_model(unet).save_pretrained(os.path.join(checkpoints_folder, f"{project}"))
avg_epoch_loss = np.mean(batch_losses)
print(f"\nЭпоха {epoch} завершена. Средний лосс: {avg_epoch_loss:.6f}")
if use_wandb:
wandb.log({"epoch_loss": avg_epoch_loss, "epoch": epoch+1})
# Завершение обучения - сохраняем финальную модель
if accelerator.is_main_process:
print("Обучение завершено! Сохраняем финальную модель...")
# Сохраняем основную модель
if save_model:
accelerator.unwrap_model(unet).save_pretrained(os.path.join(checkpoints_folder, f"{project}"))
print("Готово!")