Diffusers
Safetensors
sdxs / train_mnist.py
recoilme's picture
1b
5311f33
import os
import math
import torch
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
from torch.optim.lr_scheduler import LambdaLR
from diffusers import UNet2DConditionModel, AutoencoderKL
from accelerate import Accelerator
from datasets import load_from_disk
from tqdm import tqdm
from PIL import Image
import wandb
import random
# --------------------------- Параметры ---------------------------
save_path = "datasets/mnist"
batch_size = 320
base_learning_rate = 5e-5
num_epochs = 10
gradient_accumulation_steps = 1
project = "sdxs"
use_wandb = True
limit = 0
grad_clip = 0.1
# Параметры для диффузии
n_diffusion_steps = 20 # число шагов в цепочке диффузии (при генерации)
samples_to_generate = 6
# Папки для сохранения результатов
generated_folder = "samples"
checkpoints_folder = ""
os.makedirs(generated_folder, exist_ok=True)
seed = 42
torch.manual_seed(seed)
np.random.seed(seed)
random.seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed_all(seed)
print("init")
# --------------------------- Инициализация Accelerator ---------------------------
dtype = torch.bfloat16
accelerator = Accelerator(gradient_accumulation_steps=gradient_accumulation_steps,mixed_precision="bf16" if dtype == torch.bfloat16 else "no")
device = accelerator.device#"cuda" if torch.cuda.is_available() else "cpu"
# --------------------------- Инициализация WandB ---------------------------
if use_wandb and accelerator.is_main_process:
wandb.init(project=project, config={
"batch_size": batch_size,
"base_learning_rate": base_learning_rate,
"num_epochs": num_epochs,
"gradient_accumulation_steps": gradient_accumulation_steps,
"n_diffusion_steps": n_diffusion_steps,
"samples_to_generate": samples_to_generate,
"dtype": str(dtype)
})
# --------------------------- Загрузка датасета ---------------------------
if limit>0:
dataset = load_from_disk(save_path).select(range(limit))
else:
dataset = load_from_disk(save_path)
def collate_fn(batch):
# Преобразуем список в тензоры и перемещаем на девайс
latents = torch.tensor(np.array([item["vae"] for item in batch]), dtype=dtype).to(device)
embeddings = torch.tensor(np.array([item["embeddings"] for item in batch]), dtype=dtype).to(device)
return latents, embeddings
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)
dataloader = accelerator.prepare(dataloader)
# --------------------------- Загрузка моделей ---------------------------
# VAE загружается на CPU для экономии GPU-памяти
vae = AutoencoderKL.from_pretrained("AuraDiffusion/16ch-vae").to("cpu", dtype=dtype)
# Загружаем UNet (на устройство)
from diffusers import DDPMScheduler
# DDPMScheduler с V_Prediction и Zero-SNR
scheduler = DDPMScheduler(#DDPMScheduler(
num_train_timesteps=1000, # Полный график шагов для обучения
prediction_type="v_prediction", # V-Prediction
rescale_betas_zero_snr=True # Включение Zero-SNR
)
# Опция загрузки модели из последнего чекпоинта (если существует)
latest_checkpoint = os.path.join(checkpoints_folder, project)#check
if os.path.isdir(latest_checkpoint):
print("Загружаем UNet из чекпоинта:", latest_checkpoint)
unet = UNet2DConditionModel.from_pretrained(latest_checkpoint).to(device, dtype=dtype)
# Подготовка через Accelerator
unet = accelerator.prepare(unet)
# Расчёт общего количества шагов (учитывая gradient_accumulation_steps)
total_training_steps = len(dataloader) * num_epochs // gradient_accumulation_steps
# Get the world size
world_size = accelerator.state.num_processes
print(f"World Size: {world_size}")
# --------------------------- Оптимизатор и кастомный LR scheduler ---------------------------
optimizer = torch.optim.AdamW(unet.parameters(), lr=base_learning_rate)
def lr_schedule(step, max_steps, base_lr):
# Прогрев до 10% от max_steps, затем косинусное затухание до 0.1 * base_lr
x = step / max_steps
if x < 0.1:
return 0.1 * base_lr + 0.9 * base_lr * (x / 0.1)
else:
return 0.1 * base_lr + 0.9 * base_lr * (1 + math.cos(math.pi * (x - 0.1) / (1 - 0.1))) / 2
def custom_lr_lambda(step):
return lr_schedule(step, total_training_steps*world_size, base_learning_rate) / base_learning_rate
lr_scheduler = LambdaLR(optimizer, lr_lambda=custom_lr_lambda)
optimizer, lr_scheduler = accelerator.prepare(optimizer, lr_scheduler)
# --------------------------- Фиксированные семплы для генерации ---------------------------
# Выбираем фиксированные индексы (приводим к int)
sample_indices = np.random.choice(len(dataset), samples_to_generate, replace=False)
sample_data = [dataset[int(i)] for i in sample_indices]
def get_sample_inputs(sample_data):
# Получаем latents и текстовые эмбеддинги для выбранных примеров (collate_fn переносит на device)
return collate_fn(sample_data)
@torch.no_grad()
def generate_and_save_samples(step: int):
"""
Перемещает VAE на device, генерирует семплы с заданным числом diffusion шагов,
декодирует их и отправляет на WandB по одному изображению.
После генерации VAE возвращается на CPU.
"""
try:
# Перемещаем VAE на device для семплирования
vae.to(accelerator.device, dtype=dtype)
# Устанавливаем количество diffusion шагов
scheduler.set_timesteps(n_diffusion_steps)
# Получаем фиксированные данные
sample_latents, sample_text_embeddings = get_sample_inputs(sample_data)
sample_latents = sample_latents.to(dtype)
# Инициализируем латенты случайным шумом
sample_latents = torch.randn_like(sample_latents)
with torch.no_grad():
for t in scheduler.timesteps:
latent_model_input = scheduler.scale_model_input(sample_latents, t)
noise_pred = unet(latent_model_input, t, sample_text_embeddings).sample
sample_latents = scheduler.step(noise_pred, t, sample_latents).prev_sample
#decoded = vae.decode(sample_latents).sample
latent = (sample_latents.detach() / vae.config.scaling_factor) + vae.config.shift_factor
latent = latent.to(accelerator.device, dtype=dtype)
decoded = vae.decode(latent).sample
# Преобразуем тензоры в PIL-изображения
generated_images = []
for img_idx, img_tensor in enumerate(decoded):
img = (img_tensor.to(torch.float32) / 2 + 0.5).clamp(0, 1).cpu().numpy().transpose(1, 2, 0)
pil_img = Image.fromarray((img * 255).astype("uint8"))
generated_images.append(pil_img)
save_path = f"{generated_folder}/{project}{img_idx}.jpg"
pil_img.save(save_path, "JPEG", quality=95)
# Отправляем изображения на WandB
if use_wandb:
wandb_images = [wandb.Image(img, caption=f"Sample {i}") for i, img in enumerate(generated_images)]
wandb.log({"generated_images": wandb_images, "global_step": step})
finally:
# Гарантированное перемещение VAE обратно на CPU
vae.to("cpu")
torch.cuda.empty_cache() # Очищаем кэш CUDA
# --------------------------- Генерация сэмплов перед обучением ---------------------------
if accelerator.is_main_process:
print("Генерация сэмплов до старта обучения...")
generate_and_save_samples(step=0)
# --------------------------- Тренировочный цикл ---------------------------
global_step = 0
# Для логирования среднего лосса каждые 10% эпохи
if accelerator.is_main_process:
print(f"Total steps per GPU: {total_training_steps}")
print(f"[GPU {accelerator.process_index}] Total steps: {total_training_steps}")
epoch_loss_points = []
progress_bar = tqdm(total=total_training_steps,disable=not accelerator.is_local_main_process, desc="Training", unit="step")
# Определяем интервал для сэмплирования и логирования в пределах эпохи (10% эпохи)
steps_per_epoch = len(dataloader)// gradient_accumulation_steps
sample_interval = max(1, steps_per_epoch // 10)
for epoch in range(num_epochs):
batch_losses = []
unet.train()
for step, (latents, embeddings) in enumerate(dataloader):
#optimizer.zero_grad()
with accelerator.accumulate(unet):
# Forward pass
noise = torch.randn_like(latents)
timesteps = torch.randint(0, scheduler.config.num_train_timesteps, (latents.shape[0],), device=device).long()#long?
noisy_latents = scheduler.add_noise(latents, noise, timesteps)
noise_pred = unet(noisy_latents, timesteps, embeddings).sample
target = scheduler.get_velocity(latents, noise, timesteps)
with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
loss = torch.nn.functional.mse_loss(noise_pred, target)
accelerator.backward(loss)
if accelerator.sync_gradients:
grad_norm = torch.nn.utils.clip_grad_norm_(unet.parameters(), grad_clip)
optimizer.step()
lr_scheduler.step()
optimizer.zero_grad()
batch_losses.append(loss.item())
global_step += 1
progress_bar.update(1)
# Логирование в W&B
if accelerator.is_main_process:
if use_wandb:
wandb.log({"loss": loss.item(), "learning_rate": optimizer.param_groups[0]["lr"], "epoch": epoch})
# Логирование каждые sample_interval шагов (примерно 10% эпохи)
if (step + 1) % sample_interval == 0:
avg_loss = np.mean(batch_losses[-sample_interval:])
if use_wandb:
wandb.log({"intermediate_loss": avg_loss, "grad_norm": grad_norm.item()})
# Также проводим генерацию сэмплов
unet.eval()
generate_and_save_samples(step=global_step)
unet.train()
# Средний лосс за эпоху
accelerator.wait_for_everyone() # Дождаться завершения всех процессов
if accelerator.is_main_process:
print(f"[GPU {accelerator.process_index}] Max memory allocated: {torch.cuda.max_memory_allocated()/1e9:.2f} GB")
avg_epoch_loss = np.mean(batch_losses)
epoch_loss_points.append(avg_epoch_loss)
print(f"Epoch {epoch+1}/{num_epochs}, Average Loss: {avg_epoch_loss:.6f}")
if use_wandb:
wandb.log({"epoch_loss": avg_epoch_loss, "epoch": epoch+1})
# Сохранение модели после каждой эпохи (перезаписываем последний чекпоинт)
checkpoint_path = os.path.join(checkpoints_folder, project)
# Получаем оригинальную модель из DistributedDataParallel
original_model = unet.module if hasattr(unet, 'module') else unet
original_model.save_pretrained(checkpoint_path)
print(f"Чекпоинт сохранён: {checkpoint_path}")
progress_bar.close()
# Генерация сэмплов по окончании обучения
accelerator.wait_for_everyone()
if accelerator.is_main_process:
print("Генерация сэмплов после окончания обучения...")
generate_and_save_samples(step=global_step)
# --------------------------- Построение графика лосса ---------------------------
plt.figure()
plt.plot(np.arange(1, num_epochs+1), epoch_loss_points, marker='o')
plt.xlabel("Epoch")
plt.ylabel("Average Loss")
plt.title("Training Loss Curve")
plt.grid()
plt.savefig("training_loss.png", dpi=300)
print("Training loss curve saved to training_loss.png")
# --------------------------- Сохранение итоговой модели ---------------------------
checkpoint_path = os.path.join(checkpoints_folder, project)
original_model = unet.module if hasattr(unet, 'module') else unet
original_model.save_pretrained(checkpoint_path)
print("Итоговая модель сохранена в", checkpoint_path)
if use_wandb:
wandb.log({"training_loss_curve": wandb.Image("training_loss.png")})
wandb.save(checkpoint_path)
wandb.finish()
# --------------------------- Очистка памяти ---------------------------
del unet, vae, optimizer, lr_scheduler, dataloader
torch.cuda.empty_cache()
print("Память очищена.")